1 /*
2  * hunt-proton: AMQP Protocol library for D programming language.
3  *
4  * Copyright (C) 2018-2019 HuntLabs
5  *
6  * Website: https://www.huntlabs.net/
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module hunt.proton.engine.impl.SaslFrameParser;
13 
14 import  hunt.proton.engine.impl.AmqpHeader;
15 
16 import hunt.io.ByteBuffer;
17 
18 import hunt.proton.amqp.Binary;
19 import hunt.proton.amqp.security.SaslFrameBody;
20 import hunt.proton.codec.ByteBufferDecoder;
21 import hunt.proton.codec.DecodeException;
22 import hunt.proton.engine.TransportException;
23 import hunt.proton.engine.impl.SaslFrameHandler;
24 import hunt.proton.engine.impl.TransportImpl;
25 import hunt.proton.engine.impl.SaslImpl;
26 import hunt.proton.codec.DecodeException;
27 import hunt.proton.engine.impl.ProtocolTracer;
28 import hunt.io.BufferUtils;
29 import hunt.String;
30 import hunt.logging;
31 
/**
 * Incremental parser for the SASL layer of an AMQP connection.
 *
 * Consumes the 8-byte SASL protocol header, then a sequence of SASL frames.
 * For each frame it decodes the body with the configured decoder and passes
 * it, together with any trailing payload bytes, to the $(D SaslFrameHandler).
 * Input may arrive in arbitrarily sized chunks: the parse position is kept in
 * $(D _state)/$(D _size), and a partial frame is held in $(D _buffer)
 * between calls to $(D input).
 */
class SaslFrameParser
{
    private enum string HEADER_DESCRIPTION = "SASL";

    /// Receives decoded frames; its isDone() also stops parsing early.
    private SaslFrameHandler _sasl;

    /**
     * Parser states.
     *
     * HEADER0 .. HEADER7 expect the corresponding byte of
     * AmqpHeader.SASL_HEADER. SIZE_0 .. SIZE_3 read the 4-byte frame size one
     * byte at a time when it is split across input chunks. PRE_PARSE validates
     * the size. BUFFERING accumulates a frame that did not arrive in one chunk.
     * PARSING decodes a complete frame. ERROR is terminal.
     */
    enum State
    {
        HEADER0,
        HEADER1,
        HEADER2,
        HEADER3,
        HEADER4,
        HEADER5,
        HEADER6,
        HEADER7,
        SIZE_0,
        SIZE_1,
        SIZE_2,
        SIZE_3,
        PRE_PARSE,
        BUFFERING,
        PARSING,
        ERROR
    }

    private State _state = State.HEADER0;
    /// Frame size (including its own 4 bytes) of the frame being parsed.
    private int _size;

    /// Holds the rest of a frame after its size field while more input is awaited.
    private ByteBuffer _buffer;

    private ByteBufferDecoder _decoder;
    /// Frames whose declared size exceeds this value are rejected.
    private int _frameSizeLimit;
    private TransportImpl _transport;

    /**
     * Params:
     *   sasl           = handler that receives each decoded SASL frame
     *   decoder        = decoder used to read each frame body
     *   frameSizeLimit = largest accepted frame size, in bytes
     *   transport      = used for frame tracing / protocol tracer lookup
     */
    this(SaslFrameHandler sasl, ByteBufferDecoder decoder, int frameSizeLimit, TransportImpl transport)
    {
        _sasl = sasl;
        _decoder = decoder;
        _frameSizeLimit = frameSizeLimit;
        _transport = transport;
    }

    /**
     * Parse the provided SASL input and call my SASL frame handler with the result.
     *
     * Consumes bytes from `input` until the input is exhausted, the handler
     * reports it is done, or an error occurs. Bytes belonging to an
     * incomplete frame are copied into an internal buffer and parsing resumes
     * on the next call.
     *
     * Throws: TransportException on a header mismatch, an out-of-range frame
     * size or data offset, an unknown frame type, a body that is not a
     * SaslFrameBody, or a DecodeException from the decoder. Once the parser is
     * in the ERROR state, every later call throws as well.
     */
    public void input(ByteBuffer input)
    {
        import std.format : format;

        TransportException frameParsingError = null;
        int size = _size;
        State state = _state;
        // Non-null while `input` temporarily refers to the assembled _buffer or
        // to a frame-bounded duplicate; holds the caller's buffer so parsing
        // can continue on it after the frame is handled.
        ByteBuffer oldIn = null;

        while(input.hasRemaining() && state != State.ERROR && !_sasl.isDone())
        {
            switch(state)
            {
                case State.HEADER0: .. case State.HEADER7:
                    state = consumeHeader(input, state, frameParsingError);
                    if(state != State.SIZE_0)
                    {
                        // Either input ran out part-way through the header,
                        // or a byte mismatched and state is now ERROR.
                        break;
                    }
                    goto case State.SIZE_0;
                case State.SIZE_0:
                    if(!input.hasRemaining())
                    {
                        break;
                    }

                    if(input.remaining() >= 4)
                    {
                        size = input.getInt();
                        state = State.PRE_PARSE;
                        break;
                    }
                    else
                    {
                        // Size field is split across chunks: assemble it big-endian.
                        size = (input.get() << 24) & 0xFF000000;
                        if(!input.hasRemaining())
                        {
                            state = State.SIZE_1;
                            break;
                        }
                    }
                    goto case;
                case State.SIZE_1:
                    size |= (input.get() << 16) & 0xFF0000;
                    if(!input.hasRemaining())
                    {
                        state = State.SIZE_2;
                        break;
                    }
                    goto case;
                case State.SIZE_2:
                    size |= (input.get() << 8) & 0xFF00;
                    if(!input.hasRemaining())
                    {
                        state = State.SIZE_3;
                        break;
                    }
                    goto case;
                case State.SIZE_3:
                    size |= input.get() & 0xFF;
                    state = State.PRE_PARSE;
                    goto case;
                case State.PRE_PARSE:
                    // Also rejects negative sizes (size field with the high bit set).
                    if(size < 8)
                    {
                        frameParsingError = new TransportException(format(
                                "specified frame size %d smaller than minimum SASL frame header size 8",
                                size));
                        state = State.ERROR;
                        break;
                    }

                    if (size > _frameSizeLimit)
                    {
                        frameParsingError = new TransportException(format(
                                "specified frame size %d larger than maximum SASL frame size %d",
                                size, _frameSizeLimit));
                        state = State.ERROR;
                        break;
                    }

                    // The size field (4 bytes) has already been consumed.
                    if(input.remaining() < size-4)
                    {
                        _buffer = BufferUtils.allocate(size-4);
                        _buffer.put(input);
                        state = State.BUFFERING;
                        break;
                    }
                    goto case;
                case State.BUFFERING:
                    if(_buffer !is null)
                    {
                        if(input.remaining() < _buffer.remaining())
                        {
                            _buffer.put(input);
                            break;
                        }
                        else
                        {
                            // Copy exactly the missing bytes, leaving the rest of
                            // `input` positioned at the start of the next frame.
                            ByteBuffer dup = input.duplicate();
                            dup.limit(dup.position()+_buffer.remaining());
                            input.position(input.position()+_buffer.remaining());
                            _buffer.put(dup);
                            oldIn = input;
                            _buffer.flip();
                            input = _buffer;
                            state = State.PARSING;
                        }
                    }
                    goto case;
                case State.PARSING:
                {
                    // DOFF byte counts 4-byte words.
                    int dataOffset = (input.get() << 2) & 0x3FF;

                    if(dataOffset < 8)
                    {
                        frameParsingError = new TransportException(format(
                                "specified frame data offset %d smaller than minimum frame header size %d",
                                dataOffset, 8));
                        state = State.ERROR;
                        break;
                    }
                    else if(dataOffset > size)
                    {
                        frameParsingError = new TransportException(format(
                                "specified frame data offset %d larger than the frame size %d",
                                dataOffset, size));
                        state = State.ERROR;
                        break;
                    }

                    int type = input.get() & 0xFF;
                    // SASL frame has no type-specific content in the frame header, so we skip next two bytes
                    input.get();
                    input.get();

                    if(type != SaslImpl.SASL_FRAME_TYPE)
                    {
                        frameParsingError = new TransportException(format("unknown frame type: %d", type));
                        state = State.ERROR;
                        break;
                    }

                    // Skip any extended header between the 8 fixed bytes and the body.
                    if(dataOffset!=8)
                    {
                        input.position(input.position()+dataOffset-8);
                    }

                    // oldIn null iff not working on the assembled _buffer: bound the
                    // decoder to this frame with a duplicate and advance the caller's
                    // buffer past the frame.
                    if(oldIn is null)
                    {
                        oldIn = input;
                        input = input.duplicate();
                        int endPos = input.position() + size - dataOffset;
                        input.limit(endPos);
                        oldIn.position(endPos);
                    }

                    try
                    {
                        _decoder.setByteBuffer(input);
                        Object val = _decoder.readObject();

                        // Anything after the encoded body is the frame payload.
                        Binary payload;
                        if(input.hasRemaining())
                        {
                            byte[] payloadBytes = new byte[input.remaining()];
                            input.get(payloadBytes);
                            payload = new Binary(payloadBytes);
                        }
                        else
                        {
                            payload = null;
                        }

                        // A D class/interface cast yields null when val is not a
                        // SaslFrameBody, so test the cast result, not val.
                        SaslFrameBody frameBody = cast(SaslFrameBody)val;
                        if(frameBody !is null)
                        {
                            _sasl.handle(frameBody, payload);

                            // The _state/_size written here are overwritten from the
                            // locals when input() returns; the local state drives the loop.
                            reset();
                            input = oldIn;
                            oldIn = null;
                            _buffer = null;
                            state = State.SIZE_0;
                        }
                        else
                        {
                            state = State.ERROR;
                            frameParsingError = new TransportException("Unexpected frame type encountered");
                        }
                    }
                    catch (DecodeException ex)
                    {
                        state = State.ERROR;
                        frameParsingError = new TransportException(ex);
                    }
                    break;
                }
                case State.ERROR:
                    // do nothing
                    break;
                default:
                    break;
            }
        }

        _state = state;
        _size = size;

        if(_state == State.ERROR)
        {
            if(frameParsingError !is null)
            {
                throw frameParsingError;
            }
            else
            {
                throw new TransportException("Unable to parse, probably because of a previous error");
            }
        }
    }

    /**
     * Consumes as many SASL protocol-header bytes as are available.
     *
     * Matching starts at the header position given by `state`, which must be
     * one of HEADER0 .. HEADER7. The header is logged after its eighth byte
     * matches.
     *
     * Returns: the state to continue from:
     *   - a later HEADERn if input ran out part-way through the header;
     *   - SIZE_0 once all eight header bytes have matched;
     *   - ERROR on the first mismatching byte, in which case `error` is set.
     */
    private State consumeHeader(ByteBuffer input, State state, ref TransportException error)
    {
        import std.format : format;

        while(state <= State.HEADER7 && input.hasRemaining())
        {
            immutable size_t index = cast(size_t)(state - State.HEADER0);
            byte c = input.get();
            if(c != AmqpHeader.SASL_HEADER[index])
            {
                error = new TransportException(format(
                        "AMQP SASL header mismatch value %x, expecting %x. In state: %s",
                        c & 0xFF, AmqpHeader.SASL_HEADER[index] & 0xFF, state));
                return State.ERROR;
            }

            if(state == State.HEADER7)
            {
                logHeader();
            }
            // HEADER7 + 1 == SIZE_0.
            state = cast(State)(state + 1);
        }
        return state;
    }

    /// Clears the stored frame size and returns the stored state to SIZE_0.
    private void reset()
    {
        _size = 0;
        _state = State.SIZE_0;
    }

    /// Reports receipt of the SASL header when frame tracing is enabled.
    private void logHeader()
    {
        if (_transport.isFrameTracingEnabled())
        {
            _transport.log(TransportImpl.INCOMING, new String(HEADER_DESCRIPTION));

            ProtocolTracer tracer = _transport.getProtocolTracer();
            if (tracer !is null)
            {
                tracer.receivedHeader(HEADER_DESCRIPTION);
            }
        }
    }
}