// DtlsReliableHandshake.cs (20 KB)

  1. #if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
  2. #pragma warning disable
  3. using System;
  4. using System.Collections;
  5. using System.IO;
  6. using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
  7. using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Date;
  8. namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Tls
  9. {
  /// <summary>
  /// Reliable delivery of DTLS handshake messages on top of a <see cref="DtlsRecordLayer"/>.
  /// </summary>
  /// <remarks>
  /// Outbound messages are split into handshake fragments that fit the record layer's send limit and are
  /// remembered as the current outbound flight. Inbound fragments are reassembled per message_seq. The
  /// outbound flight is retransmitted, doubling the resend interval up to a cap, when a receive returns a
  /// negative count or when the peer's previous flight is received again in full. Every sent or delivered
  /// message except hello_request, hello_verify_request and key_update is fed into the handshake hash.
  /// </remarks>
  internal class DtlsReliableHandshake
  {
      // Fragments are buffered only for message_seq in [next expected, next expected + MAX_RECEIVE_AHEAD).
      private const int MAX_RECEIVE_AHEAD = 16;

      // Handshake header: msg_type(1) | length(3) | message_seq(2) | fragment_offset(3) | fragment_length(3).
      private const int MESSAGE_HEADER_LENGTH = 12;

      internal const int INITIAL_RESEND_MILLIS = 1000;
      private const int MAX_RESEND_MILLIS = 60000;

      /// <summary>
      /// Tries to extract a ClientHello from a single incoming datagram, without any handshake state.
      /// </summary>
      /// <returns>
      /// A <see cref="DtlsRequest"/> for the ClientHello, or null when the datagram does not hold a
      /// ClientHello sent as one complete fragment (fragment_offset 0, fragment_length equal to the
      /// declared length, and total size equal to header plus declared length).
      /// </returns>
      /// <exception cref="IOException"/>
      internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput)
      {
          // TODO Support the possibility of a fragmented ClientHello datagram
          byte[] hello = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
          if (hello == null || hello.Length < MESSAGE_HEADER_LENGTH)
          {
              return null;
          }

          // Record sequence number: a uint48 read at offset 5 of the raw record.
          long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);

          if (TlsUtilities.ReadUint8(hello, 0) != HandshakeType.client_hello)
          {
              return null;
          }

          int bodyLength = TlsUtilities.ReadUint24(hello, 1);
          if (hello.Length != MESSAGE_HEADER_LENGTH + bodyLength)
          {
              return null;
          }

          // TODO Consider stricter HelloVerifyRequest-related checks (e.g. rejecting message_seq > 1,
          // read as a uint16 at offset 4).

          // Only a single complete fragment is accepted.
          if (TlsUtilities.ReadUint24(hello, 6) != 0 || TlsUtilities.ReadUint24(hello, 9) != bodyLength)
          {
              return null;
          }

          MemoryStream bodyStream = new MemoryStream(hello, MESSAGE_HEADER_LENGTH, bodyLength, false);
          ClientHello clientHello = ClientHello.Parse(bodyStream, dtlsOutput);
          return new DtlsRequest(recordSeq, hello, clientHello);
      }

      /// <summary>
      /// Builds a HelloVerifyRequest carrying <paramref name="cookie"/> and sends it through
      /// <see cref="DtlsRecordLayer.SendHelloVerifyRequestRecord"/>.
      /// </summary>
      /// <exception cref="IOException"/>
      internal static void SendHelloVerifyRequest(DatagramSender sender, long recordSeq, byte[] cookie)
      {
          TlsUtilities.CheckUint8(cookie.Length);

          // Body = 2-byte version + opaque8 cookie (1-byte length prefix + cookie bytes).
          int bodyLength = 3 + cookie.Length;
          byte[] hvr = new byte[MESSAGE_HEADER_LENGTH + bodyLength];

          // Header; message_seq (offset 4) and fragment_offset (offset 6) stay zero from array initialisation.
          TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, hvr, 0);
          TlsUtilities.WriteUint24(bodyLength, hvr, 1);
          TlsUtilities.WriteUint24(bodyLength, hvr, 9);

          // Body fields.
          TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, hvr, MESSAGE_HEADER_LENGTH);
          TlsUtilities.WriteOpaque8(cookie, hvr, MESSAGE_HEADER_LENGTH + 2);

          DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, hvr);
      }

      private DtlsRecordLayer m_recordLayer;
      private Timeout m_handshakeTimeout;
      private TlsHandshakeHash m_handshakeHash;

      // DtlsReassembler instances keyed by message_seq, for the flight currently being received.
      private IDictionary m_currentInboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable();
      // The peer's previous flight; a complete retransmission of it triggers a resend of our flight.
      private IDictionary m_previousInboundFlight = null;
      // Messages of our current outbound flight, replayed by ResendOutboundFlight.
      private IList m_outboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateArrayList();

      // Resend interval and deadline. m_resendTimeout is armed by ReceiveMessage (and by the constructor
      // when a request is supplied) and cleared again by SendMessage.
      private int m_resendMillis = -1;
      private Timeout m_resendTimeout = null;

      // Next message_seq to assign to an outgoing message, and next one expected from the peer.
      private int m_next_send_seq = 0;
      private int m_next_receive_seq = 0;

      /// <summary>
      /// Creates the handshake driver. When <paramref name="request"/> is non-null (server side, after a
      /// stateless HelloVerifyRequest exchange), its ClientHello is treated as already received: it is
      /// hashed, and the sequence counters continue from it.
      /// </summary>
      internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport, int timeoutMillis,
          DtlsRequest request)
      {
          m_recordLayer = transport;
          m_handshakeHash = new DeferredHash(context);
          m_handshakeTimeout = Timeout.ForWaitMillis(timeoutMillis);

          if (request == null)
          {
              return;
          }

          m_resendMillis = INITIAL_RESEND_MILLIS;
          m_resendTimeout = new Timeout(m_resendMillis);

          long helloRecordSeq = request.RecordSeq;
          int helloMessageSeq = request.MessageSeq;
          byte[] helloMessage = request.Message;

          m_recordLayer.ResetAfterHelloVerifyRequestServer(helloRecordSeq);

          // Simulate a previous flight consisting of the request ClientHello
          m_currentInboundFlight[helloMessageSeq] = new DtlsReassembler(HandshakeType.client_hello,
              helloMessage.Length - MESSAGE_HEADER_LENGTH);

          // We sent HelloVerifyRequest with (message) sequence number 0
          m_next_send_seq = 1;
          m_next_receive_seq = helloMessageSeq + 1;

          m_handshakeHash.Update(helloMessage, 0, helloMessage.Length);
      }

      /// <summary>
      /// Client-side reset after a HelloVerifyRequest. Discards inbound and outbound flight state and the
      /// resend timer, expects the next inbound message at sequence 1, and resets the handshake hash.
      /// m_next_send_seq is deliberately left unchanged.
      /// </summary>
      internal void ResetAfterHelloVerifyRequestClient()
      {
          m_currentInboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable();
          m_previousInboundFlight = null;
          m_outboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateArrayList();
          m_resendMillis = -1;
          m_resendTimeout = null;

          // We're waiting for ServerHello, always with (message) sequence number 1
          m_next_receive_seq = 1;

          m_handshakeHash.Reset();
      }

      /// <summary>The transcript hash that sent and received messages are fed into.</summary>
      internal TlsHandshakeHash HandshakeHash
      {
          get { return m_handshakeHash; }
      }

      /// <summary>
      /// Returns the current handshake hash and replaces it with the result of calling StopTracking() on it.
      /// </summary>
      internal TlsHandshakeHash PrepareToFinish()
      {
          TlsHandshakeHash tracked = m_handshakeHash;
          m_handshakeHash = tracked.StopTracking();
          return tracked;
      }

      /// <summary>
      /// Sends one handshake message and appends it to the outbound flight. The first send after a
      /// receive phase starts a new flight: the resend timer is disarmed and the old flight is discarded.
      /// </summary>
      /// <exception cref="IOException"/>
      internal void SendMessage(short msg_type, byte[] body)
      {
          TlsUtilities.CheckUint24(body.Length);

          if (m_resendTimeout != null)
          {
              CheckInboundFlight();

              m_resendMillis = -1;
              m_resendTimeout = null;
              m_outboundFlight.Clear();
          }

          int seq = m_next_send_seq++;
          Message outgoing = new Message(seq, msg_type, body);

          m_outboundFlight.Add(outgoing);
          WriteMessage(outgoing);
          UpdateHandshakeMessagesDigest(outgoing);
      }

      /// <summary>
      /// Receives the next handshake message and returns its body. Throws
      /// TlsFatalAlert(unexpected_message) if its type differs from <paramref name="msg_type"/>.
      /// </summary>
      /// <exception cref="IOException"/>
      internal byte[] ReceiveMessageBody(short msg_type)
      {
          Message incoming = ReceiveMessage();
          if (incoming.Type == msg_type)
          {
              return incoming.Body;
          }

          throw new TlsFatalAlert(AlertDescription.unexpected_message);
      }

      /// <summary>
      /// Blocks until the next in-sequence handshake message is fully reassembled, then returns it after
      /// adding it to the handshake hash. Throws TlsFatalAlert(user_canceled) if the record layer is
      /// closed, and TlsTimeoutException once the handshake timeout has expired.
      /// </summary>
      /// <exception cref="IOException"/>
      internal Message ReceiveMessage()
      {
          long now = DateTimeUtilities.CurrentUnixMs();

          // First receive after sending: arm the resend timer and begin a fresh inbound flight.
          if (m_resendTimeout == null)
          {
              m_resendMillis = INITIAL_RESEND_MILLIS;
              m_resendTimeout = new Timeout(m_resendMillis, now);

              PrepareInboundFlight(BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable());
          }

          byte[] recordBuf = null;

          while (true)
          {
              if (m_recordLayer.IsClosed)
              {
                  throw new TlsFatalAlert(AlertDescription.user_canceled);
              }

              Message ready = GetPendingMessage();
              if (ready != null)
              {
                  return ready;
              }

              if (Timeout.HasExpired(m_handshakeTimeout, now))
              {
                  throw new TlsTimeoutException("Handshake timed out");
              }

              // Wait bounded by the handshake timeout, then constrained by the resend timeout.
              // NOTE: Ensure a finite wait, of at least 1ms
              int waitMillis = Timeout.ConstrainWaitMillis(
                  Timeout.GetWaitMillis(m_handshakeTimeout, now), m_resendTimeout, now);
              waitMillis = System.Math.Max(waitMillis, 1);

              // Reuse the buffer unless the record layer now allows larger records.
              int receiveLimit = m_recordLayer.GetReceiveLimit();
              if (recordBuf == null || recordBuf.Length < receiveLimit)
              {
                  recordBuf = new byte[receiveLimit];
              }

              int received = m_recordLayer.Receive(recordBuf, 0, receiveLimit, waitMillis);
              if (received >= 0)
              {
                  ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, recordBuf, 0, received);
              }
              else
              {
                  // A negative count is treated as "no record within the wait" (presumably a timeout):
                  // retransmit our last flight.
                  ResendOutboundFlight();
              }

              now = DateTimeUtilities.CurrentUnixMs();
          }
      }

      /// <summary>
      /// Completes the handshake and notifies the record layer. If our flight was the last one sent, a
      /// <see cref="Retransmit"/> handler is passed along so that late retransmissions of the peer's final
      /// flight are answered.
      /// </summary>
      internal void Finish()
      {
          DtlsHandshakeRetransmit retransmit = null;

          if (m_resendTimeout == null)
          {
              PrepareInboundFlight(null);

              if (m_previousInboundFlight != null)
              {
                  /*
                   * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
                   * when in the FINISHED state, the node that transmits the last flight (the server in an
                   * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
                   * of the peer's last flight with a retransmit of the last flight.
                   */
                  retransmit = new Retransmit(this);
              }
          }
          else
          {
              CheckInboundFlight();
          }

          m_recordLayer.HandshakeSuccessful(retransmit);
      }

      /// <summary>Doubles a resend interval, capped at MAX_RESEND_MILLIS.</summary>
      internal static int BackOff(int timeoutMillis)
      {
          /*
           * TODO[DTLS] implementations SHOULD back off handshake packet size during the
           * retransmit backoff.
           */
          int doubled = timeoutMillis * 2;
          return doubled < MAX_RESEND_MILLIS ? doubled : MAX_RESEND_MILLIS;
      }

      /// <summary>
      /// Looks for "extra" messages (message_seq at or beyond the next expected one) left in the current
      /// inbound flight. Currently this only iterates; no action is taken for such messages.
      /// </summary>
      private void CheckInboundFlight()
      {
          foreach (int seq in m_currentInboundFlight.Keys)
          {
              if (seq >= m_next_receive_seq)
              {
                  // TODO Should this be considered an error?
              }
          }
      }

      /// <summary>
      /// Returns the next expected message if its reassembly is complete (and advances
      /// m_next_receive_seq), otherwise null.
      /// </summary>
      /// <exception cref="IOException"/>
      private Message GetPendingMessage()
      {
          DtlsReassembler candidate = (DtlsReassembler)m_currentInboundFlight[m_next_receive_seq];
          if (candidate == null)
          {
              return null;
          }

          byte[] body = candidate.GetBodyIfComplete();
          if (body == null)
          {
              return null;
          }

          // A message of the current flight has completed, so stop tracking retransmissions of the
          // previous inbound flight.
          m_previousInboundFlight = null;

          int seq = m_next_receive_seq++;
          return UpdateHandshakeMessagesDigest(new Message(seq, candidate.MsgType, body));
      }

      /// <summary>
      /// Rotates flights: the current inbound flight is reset and becomes the previous one, and
      /// <paramref name="nextFlight"/> becomes current.
      /// </summary>
      private void PrepareInboundFlight(IDictionary nextFlight)
      {
          IDictionary finished = m_currentInboundFlight;
          ResetAll(finished);

          m_previousInboundFlight = finished;
          m_currentInboundFlight = nextFlight;
      }

      /// <summary>
      /// Parses consecutive handshake fragments from one record.
      /// </summary>
      /// <remarks>
      /// Fragments whose message_seq lies in [m_next_receive_seq, m_next_receive_seq + windowSize) are fed
      /// to the current flight. Fragments before that window are fed to the previous flight's existing
      /// reassemblers, if it is still tracked. If that previous flight becomes complete again, our outbound
      /// flight is resent. Parsing stops at the first truncated, malformed, or wrong-epoch fragment.
      /// </remarks>
      /// <exception cref="IOException"/>
      private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
      {
          bool touchedPreviousFlight = false;
          int pos = off;
          int remaining = len;

          while (remaining >= MESSAGE_HEADER_LENGTH)
          {
              int fragmentLength = TlsUtilities.ReadUint24(buf, pos + 9);
              int fragmentTotal = MESSAGE_HEADER_LENGTH + fragmentLength;
              if (remaining < fragmentTotal)
              {
                  // NOTE: Truncated message - ignore it
                  break;
              }

              int messageLength = TlsUtilities.ReadUint24(buf, pos + 1);
              int fragmentOffset = TlsUtilities.ReadUint24(buf, pos + 6);
              if (fragmentOffset + fragmentLength > messageLength)
              {
                  // NOTE: Malformed fragment - ignore it and the rest of the record
                  break;
              }

              /*
               * NOTE: This very simple epoch check will only work until we want to support
               * renegotiation (and we're not likely to do that anyway).
               * Finished is expected in epoch 1, every other handshake message in epoch 0.
               */
              short msgType = TlsUtilities.ReadUint8(buf, pos + 0);
              int expectedEpoch = (msgType == HandshakeType.finished) ? 1 : 0;
              if (epoch != expectedEpoch)
              {
                  break;
              }

              int messageSeq = TlsUtilities.ReadUint16(buf, pos + 4);
              int bodyOff = pos + MESSAGE_HEADER_LENGTH;

              bool tooFarAhead = messageSeq >= (m_next_receive_seq + windowSize);
              if (!tooFarAhead)
              {
                  if (messageSeq >= m_next_receive_seq)
                  {
                      DtlsReassembler current = (DtlsReassembler)m_currentInboundFlight[messageSeq];
                      if (current == null)
                      {
                          current = new DtlsReassembler(msgType, messageLength);
                          m_currentInboundFlight[messageSeq] = current;
                      }

                      current.ContributeFragment(msgType, messageLength, buf, bodyOff, fragmentOffset,
                          fragmentLength);
                  }
                  else if (m_previousInboundFlight != null)
                  {
                      /*
                       * NOTE: If we receive the previous flight of incoming messages in full again,
                       * retransmit our last flight
                       */
                      DtlsReassembler earlier = (DtlsReassembler)m_previousInboundFlight[messageSeq];
                      if (earlier != null)
                      {
                          earlier.ContributeFragment(msgType, messageLength, buf, bodyOff, fragmentOffset,
                              fragmentLength);
                          touchedPreviousFlight = true;
                      }
                  }
              }
              // else NOTE: Too far ahead - ignore

              pos += fragmentTotal;
              remaining -= fragmentTotal;
          }

          if (touchedPreviousFlight && CheckAll(m_previousInboundFlight))
          {
              ResendOutboundFlight();
              ResetAll(m_previousInboundFlight);
          }
      }

      /// <summary>
      /// Resets the record layer's write epoch, rewrites every message of the outbound flight, then backs
      /// off and re-arms the resend timer.
      /// </summary>
      /// <exception cref="IOException"/>
      private void ResendOutboundFlight()
      {
          m_recordLayer.ResetWriteEpoch();

          for (int i = 0; i < m_outboundFlight.Count; ++i)
          {
              WriteMessage((Message)m_outboundFlight[i]);
          }

          m_resendMillis = BackOff(m_resendMillis);
          m_resendTimeout = new Timeout(m_resendMillis);
      }

      /// <summary>
      /// Adds <paramref name="message"/> to the handshake hash unless it is hello_request,
      /// hello_verify_request or key_update. Returns the same message for call chaining.
      /// </summary>
      /// <exception cref="IOException"/>
      private Message UpdateHandshakeMessagesDigest(Message message)
      {
          short msgType = message.Type;

          bool excludedFromTranscript = msgType == HandshakeType.hello_request
              || msgType == HandshakeType.hello_verify_request
              || msgType == HandshakeType.key_update;

          // TODO[dtls13] new_session_ticket is not included in the transcript for (D)TLS 1.3+
          if (!excludedFromTranscript)
          {
              byte[] body = message.Body;

              // The hashed header always describes the message as one unfragmented fragment.
              byte[] header = new byte[MESSAGE_HEADER_LENGTH];
              TlsUtilities.WriteUint8(msgType, header, 0);
              TlsUtilities.WriteUint24(body.Length, header, 1);
              TlsUtilities.WriteUint16(message.Seq, header, 4);
              TlsUtilities.WriteUint24(0, header, 6);
              TlsUtilities.WriteUint24(body.Length, header, 9);

              m_handshakeHash.Update(header, 0, header.Length);
              m_handshakeHash.Update(body, 0, body.Length);
          }

          return message;
      }

      /// <summary>
      /// Sends <paramref name="message"/> as one or more handshake fragments, each no larger than the
      /// record layer's send limit. Throws TlsFatalAlert(internal_error) if no body byte would fit.
      /// </summary>
      /// <exception cref="IOException"/>
      private void WriteMessage(Message message)
      {
          int fragmentLimit = m_recordLayer.GetSendLimit() - MESSAGE_HEADER_LENGTH;

          // TODO Support a higher minimum fragment size?
          if (fragmentLimit < 1)
          {
              // TODO Should we be throwing an exception here?
              throw new TlsFatalAlert(AlertDescription.internal_error);
          }

          int total = message.Body.Length;

          // NOTE: Must still send a fragment if body is empty, hence the do/while
          int sent = 0;
          do
          {
              int chunk = System.Math.Min(total - sent, fragmentLimit);
              WriteHandshakeFragment(message, sent, chunk);
              sent += chunk;
          }
          while (sent < total);
      }

      /// <summary>
      /// Writes one handshake fragment (header plus the selected body slice) and hands it to the record layer.
      /// </summary>
      /// <exception cref="IOException"/>
      private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
      {
          RecordLayerBuffer record = new RecordLayerBuffer(MESSAGE_HEADER_LENGTH + fragment_length);

          TlsUtilities.WriteUint8(message.Type, record);
          TlsUtilities.WriteUint24(message.Body.Length, record);
          TlsUtilities.WriteUint16(message.Seq, record);
          TlsUtilities.WriteUint24(fragment_offset, record);
          TlsUtilities.WriteUint24(fragment_length, record);
          record.Write(message.Body, fragment_offset, fragment_length);

          record.SendToRecordLayer(m_recordLayer);
      }

      /// <summary>True when every reassembler in <paramref name="inboundFlight"/> has a complete body.</summary>
      private static bool CheckAll(IDictionary inboundFlight)
      {
          bool allComplete = true;
          foreach (DtlsReassembler reassembler in inboundFlight.Values)
          {
              if (null == reassembler.GetBodyIfComplete())
              {
                  allComplete = false;
                  break;
              }
          }
          return allComplete;
      }

      /// <summary>Calls Reset() on every reassembler in <paramref name="inboundFlight"/>.</summary>
      private static void ResetAll(IDictionary inboundFlight)
      {
          foreach (DtlsReassembler reassembler in inboundFlight.Values)
          {
              reassembler.Reset();
          }
      }

      /// <summary>An immutable handshake message: sequence number, type and complete body.</summary>
      internal class Message
      {
          private readonly int m_message_seq;
          private readonly short m_msg_type;
          private readonly byte[] m_body;

          internal Message(int message_seq, short msg_type, byte[] body)
          {
              m_message_seq = message_seq;
              m_msg_type = msg_type;
              m_body = body;
          }

          public int Seq
          {
              get { return m_message_seq; }
          }

          public short Type
          {
              get { return m_msg_type; }
          }

          public byte[] Body
          {
              get { return m_body; }
          }
      }

      /// <summary>
      /// Growable buffer for one outgoing fragment. It sends its written bytes to the record layer and
      /// then disposes itself.
      /// </summary>
      internal class RecordLayerBuffer
          : MemoryStream
      {
          internal RecordLayerBuffer(int size)
              : base(size)
          {
          }

          internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
          {
#if PORTABLE || NETFX_CORE
              byte[] payload = ToArray();
              int payloadLen = payload.Length;
#else
              // GetBuffer may be larger than the written data, so send only Length bytes.
              byte[] payload = GetBuffer();
              int payloadLen = (int)Length;
#endif
              recordLayer.Send(payload, 0, payloadLen);
              BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(this);
          }
      }

      /// <summary>
      /// Post-handshake retransmit handler. It forwards handshake records to ProcessRecord with a zero
      /// receive window, so only fragments of the peer's previous flight are acted upon.
      /// </summary>
      internal class Retransmit
          : DtlsHandshakeRetransmit
      {
          private readonly DtlsReliableHandshake m_outer;

          internal Retransmit(DtlsReliableHandshake outer)
          {
              m_outer = outer;
          }

          public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
          {
              m_outer.ProcessRecord(0, epoch, buf, off, len);
          }
      }
  }
  464. }
  465. #pragma warning restore
  466. #endif