DtlsReliableHandshake.cs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. #if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
  2. #pragma warning disable
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using Best.HTTP.SecureProtocol.Org.BouncyCastle.Utilities.Date;
  7. namespace Best.HTTP.SecureProtocol.Org.BouncyCastle.Tls
  8. {
  9. internal class DtlsReliableHandshake
  10. {
  /// <summary>Window size handed to ProcessRecord while ImplReceiveMessage is waiting for a message.</summary>
  private const int MAX_RECEIVE_AHEAD = 16;

  /// <summary>
  /// Size in bytes of a handshake message header as written by WriteHandshakeFragment:
  /// msg_type (1) + length (3) + message_seq (2) + fragment_offset (3) + fragment_length (3).
  /// </summary>
  private const int MESSAGE_HEADER_LENGTH = 1 + 3 + 2 + 3 + 3;

  /// <summary>Resend interval, in milliseconds, used whenever a new resend timer is started.</summary>
  internal const int INITIAL_RESEND_MILLIS = 1 * 1000;

  /// <summary>Upper bound, in milliseconds, that BackOff applies to the resend interval.</summary>
  private const int MAX_RESEND_MILLIS = 60 * 1000;
  /// <summary>
  /// Tries to interpret a received datagram as one complete (unfragmented) ClientHello handshake message.
  /// </summary>
  /// <param name="data">Buffer containing the datagram.</param>
  /// <param name="dataOff">Offset of the datagram within <paramref name="data"/>.</param>
  /// <param name="dataLen">Length of the datagram in bytes.</param>
  /// <param name="dtlsOutput">Stream handed on to <c>ClientHello.Parse</c>.</param>
  /// <returns>
  /// A <see cref="DtlsRequest"/> built from the record sequence number, the raw handshake message and the
  /// parsed ClientHello; or <c>null</c> when the datagram fails any of the checks below.
  /// </returns>
  /// <exception cref="IOException"/>
  internal static DtlsRequest ReadClientRequest(byte[] data, int dataOff, int dataLen, Stream dtlsOutput)
  {
      // TODO Support the possibility of a fragmented ClientHello datagram
      byte[] handshake = DtlsRecordLayer.ReceiveClientHelloRecord(data, dataOff, dataLen);
      if (handshake == null || handshake.Length < MESSAGE_HEADER_LENGTH)
      {
          return null;
      }

      // The 48-bit value at offset 5 of the datagram is used as the record sequence number
      long recordSeq = TlsUtilities.ReadUint48(data, dataOff + 5);

      short msgType = TlsUtilities.ReadUint8(handshake, 0);
      if (msgType != HandshakeType.client_hello)
      {
          return null;
      }

      // The declared body length must account for every byte that follows the header
      int bodyLength = TlsUtilities.ReadUint24(handshake, 1);
      if (handshake.Length - MESSAGE_HEADER_LENGTH != bodyLength)
      {
          return null;
      }

      // TODO Consider stricter HelloVerifyRequest-related checks, e.g. returning null when the
      // message_seq (uint16 at offset 4) is greater than 1

      // Only a single fragment that starts at 0 and spans the whole body is accepted
      int fragmentOffset = TlsUtilities.ReadUint24(handshake, 6);
      int fragmentLength = TlsUtilities.ReadUint24(handshake, 9);
      if (fragmentOffset != 0 || fragmentLength != bodyLength)
      {
          return null;
      }

      MemoryStream bodyInput = new MemoryStream(handshake, MESSAGE_HEADER_LENGTH, bodyLength, false);
      ClientHello clientHello = ClientHello.Parse(bodyInput, dtlsOutput);

      return new DtlsRequest(recordSeq, handshake, clientHello);
  }
  /// <summary>
  /// Builds a HelloVerifyRequest handshake message carrying <paramref name="cookie"/> and passes it to
  /// <c>DtlsRecordLayer.SendHelloVerifyRequestRecord</c>.
  /// </summary>
  /// <param name="sender">Destination for the resulting record.</param>
  /// <param name="recordSeq">Record sequence number forwarded to the record layer.</param>
  /// <param name="cookie">Cookie bytes; the length is validated with <c>TlsUtilities.CheckUint8</c>.</param>
  /// <exception cref="IOException"/>
  internal static void SendHelloVerifyRequest(DatagramSender sender, long recordSeq, byte[] cookie)
  {
      int cookieLength = cookie.Length;
      TlsUtilities.CheckUint8(cookieLength);

      // Body = version (2 bytes) + cookie length prefix (1 byte) + cookie
      int bodyLength = 3 + cookieLength;
      byte[] handshake = new byte[MESSAGE_HEADER_LENGTH + bodyLength];

      // Handshake header. message_seq (offset 4) and fragment_offset (offset 6) are both 0, which the
      // freshly allocated array already holds, so they are not written explicitly.
      TlsUtilities.WriteUint8(HandshakeType.hello_verify_request, handshake, 0);
      TlsUtilities.WriteUint24(bodyLength, handshake, 1);
      TlsUtilities.WriteUint24(bodyLength, handshake, 9);

      // HelloVerifyRequest fields
      int bodyStart = MESSAGE_HEADER_LENGTH;
      TlsUtilities.WriteVersion(ProtocolVersion.DTLSv10, handshake, bodyStart);
      TlsUtilities.WriteOpaque8(cookie, handshake, bodyStart + 2);

      DtlsRecordLayer.SendHelloVerifyRequestRecord(sender, recordSeq, handshake);
  }
  /*
   * Carried over from the Java original: these fields have no 'final' (here: 'readonly') modifiers,
   * which the Java code did so that it works in earlier JDKs.
   */
  private DtlsRecordLayer m_recordLayer;
  private Timeout m_handshakeTimeout;
  private TlsHandshakeHash m_handshakeHash;

  // Inbound reassemblers, keyed by handshake message sequence number
  private IDictionary<int, DtlsReassembler> m_currentInboundFlight = new Dictionary<int, DtlsReassembler>();
  private IDictionary<int, DtlsReassembler> m_previousInboundFlight;

  // Messages of the flight currently being sent; replayed by ResendOutboundFlight
  private IList<Message> m_outboundFlight = new List<Message>();

  // -1 / null while no resend timer is running
  private int m_resendMillis = -1;
  private Timeout m_resendTimeout;

  private int m_next_send_seq;
  private int m_next_receive_seq;
  /// <summary>
  /// Creates the handshake state on top of <paramref name="transport"/>.
  /// </summary>
  /// <param name="context">Context used to create the handshake hash.</param>
  /// <param name="transport">Record layer that handshake messages are sent to and received from.</param>
  /// <param name="timeoutMillis">Overall handshake timeout, passed to <c>Timeout.ForWaitMillis</c>.</param>
  /// <param name="request">
  /// An already received ClientHello request, or <c>null</c>. When given, the state is primed as if that
  /// ClientHello had just been received as the previous flight.
  /// </param>
  internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport, int timeoutMillis,
      DtlsRequest request)
  {
      m_recordLayer = transport;
      m_handshakeHash = new DeferredHash(context);
      m_handshakeTimeout = Timeout.ForWaitMillis(timeoutMillis);

      if (request == null)
      {
          return;
      }

      m_resendMillis = INITIAL_RESEND_MILLIS;
      m_resendTimeout = new Timeout(m_resendMillis);

      long requestRecordSeq = request.RecordSeq;
      int requestMessageSeq = request.MessageSeq;
      byte[] clientHelloMessage = request.Message;

      m_recordLayer.ResetAfterHelloVerifyRequestServer(requestRecordSeq);

      // Simulate a previous flight consisting of the request ClientHello
      int clientHelloBodyLength = clientHelloMessage.Length - MESSAGE_HEADER_LENGTH;
      m_currentInboundFlight[requestMessageSeq] =
          new DtlsReassembler(HandshakeType.client_hello, clientHelloBodyLength);

      // We sent HelloVerifyRequest with (message) sequence number 0
      m_next_send_seq = 1;
      m_next_receive_seq = requestMessageSeq + 1;

      // The ClientHello (header included) is fed to the handshake hash
      m_handshakeHash.Update(clientHelloMessage, 0, clientHelloMessage.Length);
  }
  /// <summary>
  /// Discards the inbound and outbound flights, stops the resend timer, expects message sequence
  /// number 1 next and resets the handshake hash. The send sequence number is left as it is.
  /// </summary>
  internal void ResetAfterHelloVerifyRequestClient()
  {
      // Drop everything collected or sent so far
      m_outboundFlight = new List<Message>();
      m_previousInboundFlight = null;
      m_currentInboundFlight = new Dictionary<int, DtlsReassembler>();

      // No resend timer is running
      m_resendTimeout = null;
      m_resendMillis = -1;

      // We're waiting for ServerHello, always with (message) sequence number 1
      m_next_receive_seq = 1;

      m_handshakeHash.Reset();
  }
  /// <summary>The handshake hash that sent and received messages are fed into.</summary>
  internal TlsHandshakeHash HandshakeHash => m_handshakeHash;
  /// <summary>Calls <c>StopTracking</c> on the handshake hash.</summary>
  internal void PrepareToFinish() => m_handshakeHash.StopTracking();
  /// <summary>
  /// Sends a handshake message with the next send sequence number, remembers it in the outbound flight
  /// and adds it to the handshake hash.
  /// </summary>
  /// <param name="msg_type">Handshake message type.</param>
  /// <param name="body">Message body; its length is validated with <c>TlsUtilities.CheckUint24</c>.</param>
  /// <exception cref="IOException"/>
  internal void SendMessage(short msg_type, byte[] body)
  {
      TlsUtilities.CheckUint24(body.Length);

      // A running resend timer means we were receiving until now, so this message opens a new
      // outbound flight: stop the timer and forget the previously sent flight.
      bool wasReceiving = m_resendTimeout != null;
      if (wasReceiving)
      {
          CheckInboundFlight();

          m_outboundFlight.Clear();
          m_resendTimeout = null;
          m_resendMillis = -1;
      }

      int seq = m_next_send_seq;
      m_next_send_seq = seq + 1;

      Message outgoing = new Message(seq, msg_type, body);
      m_outboundFlight.Add(outgoing);

      WriteMessage(outgoing);
      UpdateHandshakeMessagesDigest(outgoing);
  }
  /// <summary>
  /// Receives the next handshake message, of whatever type, and adds it to the handshake hash.
  /// </summary>
  /// <returns>The received message.</returns>
  /// <exception cref="IOException"/>
  internal Message ReceiveMessage()
  {
      Message received = ImplReceiveMessage();

      UpdateHandshakeMessagesDigest(received);

      return received;
  }
  /// <summary>
  /// Receives the next handshake message, requires it to be of type <paramref name="msg_type"/>, adds it
  /// to the handshake hash and returns its body.
  /// </summary>
  /// <param name="msg_type">The only handshake message type that is acceptable.</param>
  /// <returns>The body of the received message.</returns>
  /// <exception cref="TlsFatalAlert">unexpected_message, if a different message type arrives.</exception>
  /// <exception cref="IOException"/>
  internal byte[] ReceiveMessageBody(short msg_type)
  {
      Message received = ImplReceiveMessage();

      if (received.Type == msg_type)
      {
          UpdateHandshakeMessagesDigest(received);
          return received.Body;
      }

      throw new TlsFatalAlert(AlertDescription.unexpected_message);
  }
  /// <summary>
  /// Receives the next handshake message and requires it to be of type <paramref name="msg_type"/>.
  /// The message is NOT added to the handshake hash here; the caller can do so later through
  /// <see cref="UpdateHandshakeMessagesDigest"/>.
  /// </summary>
  /// <param name="msg_type">The only handshake message type that is acceptable.</param>
  /// <returns>The received message.</returns>
  /// <exception cref="TlsFatalAlert">unexpected_message, if a different message type arrives.</exception>
  /// <exception cref="IOException"/>
  internal Message ReceiveMessageDelayedDigest(short msg_type)
  {
      Message received = ImplReceiveMessage();

      bool isExpectedType = received.Type == msg_type;
      if (!isExpectedType)
      {
          throw new TlsFatalAlert(AlertDescription.unexpected_message);
      }

      return received;
  }
  /// <summary>
  /// Adds a handshake message to the handshake hash, in the form of an unfragmented message: a header
  /// with fragment_offset 0 and fragment_length equal to the body length, followed by the body.
  /// Messages of type hello_request, hello_verify_request and key_update are skipped.
  /// </summary>
  /// <param name="message">The message to hash.</param>
  /// <exception cref="IOException"/>
  internal void UpdateHandshakeMessagesDigest(Message message)
  {
      short msg_type = message.Type;

      if (msg_type == HandshakeType.hello_request
          || msg_type == HandshakeType.hello_verify_request
          || msg_type == HandshakeType.key_update)
      {
          // These types are left out of the handshake hash
          return;
      }

      // TODO[dtls13] new_session_ticket is not included in the transcript for (D)TLS 1.3+;
      // for now it is hashed like every other message type

      byte[] body = message.Body;
      int bodyLength = body.Length;

      byte[] header = new byte[MESSAGE_HEADER_LENGTH];
      TlsUtilities.WriteUint8(msg_type, header, 0);
      TlsUtilities.WriteUint24(bodyLength, header, 1);
      TlsUtilities.WriteUint16(message.Seq, header, 4);
      TlsUtilities.WriteUint24(0, header, 6);
      TlsUtilities.WriteUint24(bodyLength, header, 9);

      m_handshakeHash.Update(header, 0, header.Length);
      m_handshakeHash.Update(body, 0, bodyLength);
  }
  /// <summary>
  /// Completes the handshake and reports success to the record layer. When no resend timer is running
  /// (nothing has been received since the last send), a <see cref="Retransmit"/> handler is handed
  /// over so that the last outbound flight can still be retransmitted afterwards.
  /// </summary>
  internal void Finish()
  {
      DtlsHandshakeRetransmit retransmit = null;

      if (m_resendTimeout == null)
      {
          // Retire the current inbound flight; there is no next one
          PrepareInboundFlight(null);

          if (m_previousInboundFlight != null)
          {
              /*
               * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
               * when in the FINISHED state, the node that transmits the last flight (the server in an
               * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
               * of the peer's last flight with a retransmit of the last flight.
               */
              retransmit = new Retransmit(this);
          }
      }
      else
      {
          CheckInboundFlight();
      }

      m_recordLayer.HandshakeSuccessful(retransmit);
  }
  /// <summary>
  /// Computes the next resend interval: twice the current one, capped at MAX_RESEND_MILLIS.
  /// </summary>
  /// <param name="timeoutMillis">The current resend interval in milliseconds.</param>
  /// <returns>
  /// <c>timeoutMillis * 2</c> while that stays below the cap, otherwise MAX_RESEND_MILLIS.
  /// </returns>
  internal static int BackOff(int timeoutMillis)
  {
      /*
       * TODO[DTLS] implementations SHOULD back off handshake packet size during the
       * retransmit backoff.
       */

      // Compare against half the cap BEFORE doubling. Doubling first wraps around for inputs above
      // int.MaxValue / 2 and would yield a negative interval instead of the cap.
      if (timeoutMillis >= MAX_RESEND_MILLIS / 2)
      {
          return MAX_RESEND_MILLIS;
      }

      return timeoutMillis * 2;
  }
  /// <summary>
  /// Check that there are no "extra" messages left in the current inbound flight, i.e. entries whose
  /// sequence number is at or beyond the next one expected. Such entries are currently only visited;
  /// nothing is done about them.
  /// </summary>
  private void CheckInboundFlight()
  {
      int nextExpectedSeq = m_next_receive_seq;

      foreach (KeyValuePair<int, DtlsReassembler> entry in m_currentInboundFlight)
      {
          if (entry.Key < nextExpectedSeq)
          {
              continue;
          }

          // TODO Should this be considered an error?
      }
  }
  /// <summary>
  /// Returns the message with the next expected sequence number if the current inbound flight already
  /// holds it completely. On success the expected sequence number advances and the previous inbound
  /// flight is forgotten.
  /// </summary>
  /// <returns>The completed message, or <c>null</c> if it is absent or still incomplete.</returns>
  /// <exception cref="IOException"/>
  private Message GetPendingMessage()
  {
      if (!m_currentInboundFlight.TryGetValue(m_next_receive_seq, out DtlsReassembler reassembler))
      {
          return null;
      }

      byte[] completeBody = reassembler.GetBodyIfComplete();
      if (completeBody == null)
      {
          return null;
      }

      m_previousInboundFlight = null;

      int seq = m_next_receive_seq++;
      return new Message(seq, reassembler.MsgType, completeBody);
  }
  /// <summary>
  /// Blocks until the handshake message with the next expected sequence number has been received in
  /// full. While waiting, received records are processed, and the outbound flight is resent whenever
  /// the record layer's Receive returns a negative value.
  /// </summary>
  /// <returns>The next handshake message.</returns>
  /// <exception cref="TlsFatalAlert">user_canceled, if the record layer is closed.</exception>
  /// <exception cref="TlsTimeoutException">If the overall handshake timeout has expired.</exception>
  /// <exception cref="IOException"/>
  private Message ImplReceiveMessage()
  {
      long now = DateTimeUtilities.CurrentUnixMs();

      if (m_resendTimeout == null)
      {
          // First receive after sending: start the resend timer and open a new inbound flight
          m_resendMillis = INITIAL_RESEND_MILLIS;
          m_resendTimeout = new Timeout(m_resendMillis, now);

          PrepareInboundFlight(new Dictionary<int, DtlsReassembler>());
      }

      byte[] receiveBuf = null;

      while (true)
      {
          if (m_recordLayer.IsClosed)
          {
              throw new TlsFatalAlert(AlertDescription.user_canceled);
          }

          Message ready = GetPendingMessage();
          if (ready != null)
          {
              return ready;
          }

          if (Timeout.HasExpired(m_handshakeTimeout, now))
          {
              throw new TlsTimeoutException("Handshake timed out");
          }

          // Wait no longer than the handshake timeout, constrained further by the resend timeout
          int waitMillis = Timeout.ConstrainWaitMillis(
              Timeout.GetWaitMillis(m_handshakeTimeout, now), m_resendTimeout, now);

          // NOTE: Ensure a finite wait, of at least 1ms
          waitMillis = System.Math.Max(1, waitMillis);

          // (Re)allocate the receive buffer only when the current one is too small
          int receiveLimit = m_recordLayer.GetReceiveLimit();
          if (receiveBuf == null || receiveBuf.Length < receiveLimit)
          {
              receiveBuf = new byte[receiveLimit];
          }

          int received = m_recordLayer.Receive(receiveBuf, 0, receiveLimit, waitMillis);
          if (received >= 0)
          {
              ProcessRecord(MAX_RECEIVE_AHEAD, m_recordLayer.ReadEpoch, receiveBuf, 0, received);
          }
          else
          {
              ResendOutboundFlight();
          }

          now = DateTimeUtilities.CurrentUnixMs();
      }
  }
  /// <summary>
  /// Resets every reassembler of the current inbound flight, makes that flight the previous one and
  /// installs <paramref name="nextFlight"/> as the current one.
  /// </summary>
  /// <param name="nextFlight">The new current inbound flight; Finish passes <c>null</c>.</param>
  private void PrepareInboundFlight(IDictionary<int, DtlsReassembler> nextFlight)
  {
      IDictionary<int, DtlsReassembler> retiring = m_currentInboundFlight;

      ResetAll(retiring);

      m_currentInboundFlight = nextFlight;
      m_previousInboundFlight = retiring;
  }
  /// <summary>
  /// Walks through the handshake message fragments contained in one record and hands each of them to
  /// the matching reassembler of the current or the previous inbound flight. If the previous flight
  /// turns out to be complete again, the last outbound flight is resent.
  /// </summary>
  /// <param name="windowSize">
  /// Number of sequence numbers, counted from the next expected one, accepted for the current flight.
  /// </param>
  /// <param name="epoch">Epoch the record was received in.</param>
  /// <param name="buf">Buffer holding the record content.</param>
  /// <param name="off">Offset of the first fragment within <paramref name="buf"/>.</param>
  /// <param name="len">Number of bytes available from <paramref name="off"/>.</param>
  /// <exception cref="IOException"/>
  private void ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
  {
      bool previousFlightTouched = false;

      int pos = off;
      int remaining = len;

      while (remaining >= MESSAGE_HEADER_LENGTH)
      {
          int fragmentLength = TlsUtilities.ReadUint24(buf, pos + 9);
          int fragmentTotal = MESSAGE_HEADER_LENGTH + fragmentLength;
          if (remaining < fragmentTotal)
          {
              // NOTE: Truncated message - ignore it
              break;
          }

          int bodyLength = TlsUtilities.ReadUint24(buf, pos + 1);
          int fragmentOffset = TlsUtilities.ReadUint24(buf, pos + 6);
          if (fragmentOffset + fragmentLength > bodyLength)
          {
              // NOTE: Malformed fragment - ignore it and the rest of the record
              break;
          }

          /*
           * NOTE: This very simple epoch check will only work until we want to support
           * renegotiation (and we're not likely to do that anyway).
           */
          short msgType = TlsUtilities.ReadUint8(buf, pos);
          int requiredEpoch = (msgType == HandshakeType.finished) ? 1 : 0;
          if (epoch != requiredEpoch)
          {
              break;
          }

          int messageSeq = TlsUtilities.ReadUint16(buf, pos + 4);
          int bodyPos = pos + MESSAGE_HEADER_LENGTH;

          if (messageSeq >= (m_next_receive_seq + windowSize))
          {
              // NOTE: Too far ahead - ignore
          }
          else if (messageSeq >= m_next_receive_seq)
          {
              // Belongs to the current flight: fetch the reassembler, creating it on first sight
              DtlsReassembler current;
              if (!m_currentInboundFlight.TryGetValue(messageSeq, out current))
              {
                  current = new DtlsReassembler(msgType, bodyLength);
                  m_currentInboundFlight[messageSeq] = current;
              }

              current.ContributeFragment(msgType, bodyLength, buf, bodyPos, fragmentOffset,
                  fragmentLength);
          }
          else if (m_previousInboundFlight != null)
          {
              /*
               * NOTE: If we receive the previous flight of incoming messages in full again,
               * retransmit our last flight
               */
              DtlsReassembler previous;
              if (m_previousInboundFlight.TryGetValue(messageSeq, out previous))
              {
                  previous.ContributeFragment(msgType, bodyLength, buf, bodyPos, fragmentOffset,
                      fragmentLength);
                  previousFlightTouched = true;
              }
          }

          pos += fragmentTotal;
          remaining -= fragmentTotal;
      }

      if (previousFlightTouched && CheckAll(m_previousInboundFlight))
      {
          ResendOutboundFlight();
          ResetAll(m_previousInboundFlight);
      }
  }
  /// <summary>
  /// Resets the record layer's write epoch, writes every message of the outbound flight again and
  /// restarts the resend timer with the backed-off interval.
  /// </summary>
  /// <exception cref="IOException"/>
  private void ResendOutboundFlight()
  {
      m_recordLayer.ResetWriteEpoch();

      int flightSize = m_outboundFlight.Count;
      for (int i = 0; i < flightSize; ++i)
      {
          WriteMessage(m_outboundFlight[i]);
      }

      int nextInterval = BackOff(m_resendMillis);
      m_resendMillis = nextInterval;
      m_resendTimeout = new Timeout(nextInterval);
  }
  /// <summary>
  /// Writes a handshake message to the record layer, split into as many fragments as the record
  /// layer's send limit requires. A message with an empty body is still sent as one fragment.
  /// </summary>
  /// <param name="message">The message to write.</param>
  /// <exception cref="TlsFatalAlert">
  /// internal_error, if the send limit leaves no room for at least one body byte per fragment.
  /// </exception>
  /// <exception cref="IOException"/>
  private void WriteMessage(Message message)
  {
      int maxFragmentLength = m_recordLayer.GetSendLimit() - MESSAGE_HEADER_LENGTH;

      // TODO Support a higher minimum fragment size?
      if (maxFragmentLength < 1)
      {
          // TODO Should we be throwing an exception here?
          throw new TlsFatalAlert(AlertDescription.internal_error);
      }

      int bodyLength = message.Body.Length;
      int written = 0;

      // NOTE: Must still send a fragment if body is empty, hence the exit test at the bottom
      while (true)
      {
          int chunk = System.Math.Min(bodyLength - written, maxFragmentLength);
          WriteHandshakeFragment(message, written, chunk);
          written += chunk;

          if (written >= bodyLength)
          {
              break;
          }
      }
  }
  /// <summary>
  /// Serialises one fragment of a handshake message (12-byte header plus the selected slice of the
  /// body) and sends it to the record layer.
  /// </summary>
  /// <param name="message">The message the fragment belongs to.</param>
  /// <param name="fragment_offset">Offset of the fragment within the message body.</param>
  /// <param name="fragment_length">Number of body bytes carried by the fragment.</param>
  /// <exception cref="IOException"/>
  private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
  {
      int capacity = MESSAGE_HEADER_LENGTH + fragment_length;
      RecordLayerBuffer output = new RecordLayerBuffer(capacity);

      // Header: msg_type, length of the whole body, message_seq, fragment_offset, fragment_length
      TlsUtilities.WriteUint8(message.Type, output);
      TlsUtilities.WriteUint24(message.Body.Length, output);
      TlsUtilities.WriteUint16(message.Seq, output);
      TlsUtilities.WriteUint24(fragment_offset, output);
      TlsUtilities.WriteUint24(fragment_length, output);

      // Payload: the requested slice of the body
      output.Write(message.Body, fragment_offset, fragment_length);

      output.SendToRecordLayer(m_recordLayer);
  }
  /// <summary>
  /// Tells whether every reassembler of <paramref name="inboundFlight"/> holds a complete body.
  /// </summary>
  /// <param name="inboundFlight">The flight to inspect.</param>
  /// <returns>
  /// <c>false</c> as soon as one reassembler is incomplete; <c>true</c> otherwise, including for an
  /// empty flight.
  /// </returns>
  private static bool CheckAll(IDictionary<int, DtlsReassembler> inboundFlight)
  {
      foreach (KeyValuePair<int, DtlsReassembler> entry in inboundFlight)
      {
          bool complete = entry.Value.GetBodyIfComplete() != null;
          if (!complete)
          {
              return false;
          }
      }

      return true;
  }
  /// <summary>Calls <c>Reset</c> on every reassembler of <paramref name="inboundFlight"/>.</summary>
  /// <param name="inboundFlight">The flight whose reassemblers are reset.</param>
  private static void ResetAll(IDictionary<int, DtlsReassembler> inboundFlight)
  {
      foreach (KeyValuePair<int, DtlsReassembler> entry in inboundFlight)
      {
          entry.Value.Reset();
      }
  }
  /// <summary>
  /// A complete handshake message: its sequence number, its type and its body. The values are fixed
  /// at construction; the body array is stored by reference, not copied.
  /// </summary>
  internal class Message
  {
      internal Message(int message_seq, short msg_type, byte[] body)
      {
          Seq = message_seq;
          Type = msg_type;
          Body = body;
      }

      /// <summary>The handshake message sequence number.</summary>
      public int Seq { get; }

      /// <summary>The handshake message type.</summary>
      public short Type { get; }

      /// <summary>The message body, i.e. the same array instance that was given to the constructor.</summary>
      public byte[] Body { get; }
  }
  /// <summary>
  /// A <see cref="MemoryStream"/> in which one record payload is assembled and which can then hand its
  /// content to a record layer.
  /// </summary>
  internal class RecordLayerBuffer : MemoryStream
  {
      /// <param name="size">Initial capacity of the underlying buffer, in bytes.</param>
      internal RecordLayerBuffer(int size) : base(size)
      {
      }

      /// <summary>
      /// Sends everything written so far to <paramref name="recordLayer"/>, straight from the
      /// underlying buffer, and then disposes this stream.
      /// </summary>
      /// <param name="recordLayer">The record layer that receives the bytes.</param>
      internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
      {
          byte[] content = GetBuffer();

          // Length is a long; the conversion throws OverflowException if it does not fit an int
          int contentLength = checked((int)Length);

          recordLayer.Send(content, 0, contentLength);

          Dispose();
      }
  }
  /// <summary>
  /// <see cref="DtlsHandshakeRetransmit"/> implementation that Finish hands to the record layer. It
  /// routes handshake records back into the owning <see cref="DtlsReliableHandshake"/>.
  /// </summary>
  internal class Retransmit : DtlsHandshakeRetransmit
  {
      private readonly DtlsReliableHandshake m_outer;

      /// <param name="outer">The handshake whose ProcessRecord is invoked.</param>
      internal Retransmit(DtlsReliableHandshake outer)
      {
          m_outer = outer;
      }

      /// <summary>
      /// Forwards the record to the owning handshake's ProcessRecord with a window size of 0.
      /// </summary>
      public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
          => m_outer.ProcessRecord(0, epoch, buf, off, len);
  }
  467. }
  468. }
  469. #pragma warning restore
  470. #endif