diff --git a/Assets/Mirage/Runtime/NetworkClient.cs b/Assets/Mirage/Runtime/NetworkClient.cs index 1ba53b495c4..567f74b6ecd 100644 --- a/Assets/Mirage/Runtime/NetworkClient.cs +++ b/Assets/Mirage/Runtime/NetworkClient.cs @@ -142,16 +142,16 @@ public void Connect(string address = null, ushort? port = null) if (logger.LogEnabled()) logger.Log($"Client connecting to endpoint: {endPoint}"); var socket = SocketFactory.CreateClientSocket(); - var maxPacketSize = SocketFactory.MaxPacketSize; + var socketInfo = SocketFactory.SocketInfo; MessageHandler = new MessageHandler(World, DisconnectOnException, RethrowException); var dataHandler = new DataHandler(MessageHandler); Metrics = EnablePeerMetrics ? new Metrics(MetricsSize) : null; var config = PeerConfig ?? new Config(); - NetworkWriterPool.Configure(maxPacketSize); + NetworkWriterPool.Configure(socketInfo.MaxSize); - _peer = new Peer(socket, maxPacketSize, dataHandler, config, LogFactory.GetLogger(), Metrics); + _peer = new Peer(socket, socketInfo, dataHandler, config, LogFactory.GetLogger(), Metrics); _peer.OnConnected += Peer_OnConnected; _peer.OnConnectionFailed += Peer_OnConnectionFailed; _peer.OnDisconnected += Peer_OnDisconnected; diff --git a/Assets/Mirage/Runtime/NetworkServer.cs b/Assets/Mirage/Runtime/NetworkServer.cs index e78a88d8893..18c80f259cf 100644 --- a/Assets/Mirage/Runtime/NetworkServer.cs +++ b/Assets/Mirage/Runtime/NetworkServer.cs @@ -260,14 +260,14 @@ public void StartServer(NetworkClient localClient = null) // If not, that's okay. Some games use a non-listening server for their single player game mode (Battlefield, Call of Duty...) if (Listening) { - var maxPacketSize = SocketFactory.MaxPacketSize; - NetworkWriterPool.Configure(maxPacketSize); + var socketInfo = SocketFactory.SocketInfo; + NetworkWriterPool.Configure(socketInfo.MaxSize); // Create a server specific socket. var socket = SocketFactory.CreateServerSocket(); // Tell the peer to use that newly created socket. - _peer = new Peer(socket, maxPacketSize, dataHandler, config, LogFactory.GetLogger(), Metrics); + _peer = new Peer(socket, socketInfo, dataHandler, config, LogFactory.GetLogger(), Metrics); _peer.OnConnected += Peer_OnConnected; _peer.OnDisconnected += Peer_OnDisconnected; // Bind it to the endpoint. diff --git a/Assets/Mirage/Runtime/SocketLayer/ByteBuffer.cs b/Assets/Mirage/Runtime/SocketLayer/ByteBuffer.cs index 6eaa58b8204..32fccf5e834 100644 --- a/Assets/Mirage/Runtime/SocketLayer/ByteBuffer.cs +++ b/Assets/Mirage/Runtime/SocketLayer/ByteBuffer.cs @@ -3,7 +3,7 @@ namespace Mirage.SocketLayer { /// - /// Warpper around a byte[] that belongs to a + /// Wrapper around a byte[] that belongs to a /// public sealed class ByteBuffer : IDisposable { diff --git a/Assets/Mirage/Runtime/SocketLayer/Config.cs b/Assets/Mirage/Runtime/SocketLayer/Config.cs index 922867e4be6..d89c6c3bb6b 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Config.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Config.cs @@ -98,11 +98,6 @@ public class Config /// max value is 255 /// public int MaxReliableFragments = 50; - - /// - /// Enable if the Socket you are using has its own Reliable layer. For example using Websocket, which is TCP. - /// - public bool DisableReliableLayer = false; #endregion } } diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs index aa07075970a..a0213b7c4de 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/AckSystem.cs @@ -237,7 +237,7 @@ private bool ShouldSendEmptyAck() [MethodImpl(MethodImplOptions.AggressiveInlining)] private void Send(byte[] final, int length) { - _connection.SendRaw(final, length); + _connection.SendRaw(final, length, SendMode.Unreliable); OnSend(); } @@ -261,7 +261,7 @@ private void SendAck() ByteUtils.WriteUShort(final.array, ref offset, _latestAckSequence); ByteUtils.WriteULong(final.array, ref offset, _ackMask); - _connection.SendRaw(final.array, offset); + _connection.SendRaw(final.array, offset, SendMode.Unreliable); Send(final.array, offset); } } diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs index e5f4168cde2..d4668d4aa36 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/Batch.cs @@ -1,15 +1,18 @@ using System; +using UnityEngine; namespace Mirage.SocketLayer { public abstract class Batch { public const int MESSAGE_LENGTH_SIZE = 2; + public const int MAX_BATCH_SIZE = ushort.MaxValue; private readonly int _maxPacketSize; public Batch(int maxPacketSize) { + _maxPacketSize = maxPacketSize; } @@ -39,7 +42,7 @@ public void AddMessage(byte[] message, int offset, int length) AddToBatch(message, offset, length); } - private void AddToBatch(byte[] message, int offset, int length) + protected virtual void AddToBatch(byte[] message, int offset, int length) { var batch = GetBatch(); ref var batchLength = ref GetBatchLength(); @@ -57,18 +60,21 @@ public void Flush() public class ArrayBatch : Batch { - private readonly Action _send; + private readonly IRawConnection _connection; private readonly PacketType _packetType; - + private readonly SendMode _sendMode; private readonly byte[] _batch; + protected readonly ILogger _logger; private int _batchLength; - public ArrayBatch(int maxPacketSize, Action send, PacketType reliable) + public ArrayBatch(int maxPacketSize, ILogger logger, IRawConnection connection, PacketType reliable, SendMode sendMode) : base(maxPacketSize) { + _logger = logger; _batch = new byte[maxPacketSize]; - _send = send; + _connection = connection; _packetType = reliable; + _sendMode = sendMode; } protected override bool Created => _batchLength > 0; @@ -84,9 +90,32 @@ protected override void CreateNewBatch() protected override void SendAndReset() { - _send.Invoke(_batch, _batchLength); + _connection.SendRaw(_batch, _batchLength, _sendMode); _batchLength = 0; } + + protected override void AddToBatch(byte[] message, int offset, int length) + { + if (length > MAX_BATCH_SIZE) + { + var batch = GetBatch(); + ref var batchLength = ref GetBatchLength(); + _logger.Assert(batchLength == 1, "if length is large, then batch should be new (empty) packet"); + + // write zero as flag for large message, + // normal message will have atleast 1 length + ByteUtils.WriteUShort(batch, ref batchLength, 0); + Buffer.BlockCopy(message, offset, batch, batchLength, length); + batchLength += length; + + // we can send right away, nothing else will fit in this message + SendAndReset(); + } + else + { + base.AddToBatch(message, offset, length); + } + } } public class ReliableBatch : Batch, IDisposable diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/Connection.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/Connection.cs index 779613c6aa5..7c19346d1be 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/Connection.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/Connection.cs @@ -4,10 +4,10 @@ namespace Mirage.SocketLayer { - internal abstract class Connection : IConnection + internal abstract class Connection : IConnection, IRawConnection { protected readonly ILogger _logger; - protected readonly int _maxPacketSize; + protected readonly SocketInfo _socketInfo; protected readonly Peer _peer; protected readonly IDataHandler _dataHandler; @@ -55,11 +55,11 @@ public ConnectionState State public bool Connected => State == ConnectionState.Connected; - protected Connection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, int maxPacketSize, Time time, ILogger logger, Metrics metrics) + protected Connection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, SocketInfo socketInfo, Time time, ILogger logger, Metrics metrics) { _peer = peer; _logger = logger; - _maxPacketSize = maxPacketSize; + _socketInfo = socketInfo; EndPoint = endPoint ?? throw new ArgumentNullException(nameof(endPoint)); _dataHandler = dataHandler ?? throw new ArgumentNullException(nameof(dataHandler)); @@ -73,6 +73,11 @@ protected Connection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Co _metrics = metrics; } + void IRawConnection.SendRaw(byte[] packet, int length, SendMode mode) + { + _peer.Send(this, packet, length, mode); + } + public override string ToString() { return $"[{EndPoint}]"; @@ -207,14 +212,30 @@ private void UpdateConnected() protected void HandleReliableBatched(byte[] array, int offset, int packetLength, PacketType packetType) { + var firstPacket = true; while (offset < packetLength) { - var length = ByteUtils.ReadUShort(array, ref offset); + var length = (int)ByteUtils.ReadUShort(array, ref offset); + if (length == 0)// not batched + { + if (!firstPacket) + { + // only first message can be not batched + Disconnect(DisconnectReason.InvalidPacket); + return; + } + + _logger.Assert(offset == 3); + // set real length + length = packetLength - offset; + } + var message = new ArraySegment(array, offset, length); offset += length; _metrics?.OnReceiveMessage(packetType, length); _dataHandler.ReceiveMessage(this, message); + firstPacket = false; } } diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/IConnection.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/IConnection.cs index 4dd044c20d4..6110e8bd2c4 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/IConnection.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/IConnection.cs @@ -11,7 +11,7 @@ public interface IRawConnection /// packet given to this function as assumed to already have a header /// /// header and messages - void SendRaw(byte[] packet, int length); + void SendRaw(byte[] packet, int length, SendMode mode); } /// diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/NoReliableConnection.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/NoReliableConnection.cs index 646007ebe3e..4f947bb3123 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/NoReliableConnection.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/NoReliableConnection.cs @@ -4,7 +4,7 @@ namespace Mirage.SocketLayer { /// - /// Connection that does not run its own reliablity layer, good for TCP sockets + /// Connection that does not run its own reliability layer, good for TCP sockets /// internal sealed class NoReliableConnection : Connection { @@ -12,20 +12,12 @@ internal sealed class NoReliableConnection : Connection private readonly Batch _nextBatchReliable; - internal NoReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, int maxPacketSize, Time time, ILogger logger, Metrics metrics) - : base(peer, endPoint, dataHandler, config, maxPacketSize, time, logger, metrics) + internal NoReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, SocketInfo socketInfo, Time time, ILogger logger, Metrics metrics) + : base(peer, endPoint, dataHandler, config, socketInfo, time, logger, metrics) { - _nextBatchReliable = new ArrayBatch(maxPacketSize, SendBatchInternal, PacketType.Reliable); + Debug.Assert(socketInfo.Reliability == SocketReliability.Reliable); - if (maxPacketSize > ushort.MaxValue) - { - throw new ArgumentException($"Max package size can not bigger than {ushort.MaxValue}. NoReliableConnection uses 2 bytes for message length, maxPacketSize over that value will mean that message will be incorrectly batched."); - } - } - - private void SendBatchInternal(byte[] batch, int length) - { - _peer.Send(this, batch, length); + _nextBatchReliable = new ArrayBatch(socketInfo.MaxReliableSize, logger, this, PacketType.Reliable, SendMode.Reliable); } // just sue SendReliable for unreliable/notify @@ -51,9 +43,9 @@ public override void SendReliable(byte[] message, int offset, int length) { ThrowIfNotConnectedOrConnecting(); - if (length + HEADER_SIZE > _maxPacketSize) + if (length + HEADER_SIZE > _socketInfo.MaxReliableSize) { - throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_maxPacketSize - HEADER_SIZE}"); + throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_socketInfo.MaxReliableSize - HEADER_SIZE}"); } _nextBatchReliable.AddMessage(message, offset, length); diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs new file mode 100644 index 00000000000..59fed7b60f9 --- /dev/null +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs @@ -0,0 +1,144 @@ +using System; +using UnityEngine; + +namespace Mirage.SocketLayer +{ + internal class PassthroughConnection : Connection, IRawConnection + { + private const int HEADER_SIZE = 1 + Batch.MESSAGE_LENGTH_SIZE; + + private readonly Batch _reliableBatch; + private readonly Batch _unreliableBatch; + private readonly AckSystem _ackSystem; + + public PassthroughConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, SocketInfo socketInfo, Time time, Pool bufferPool, ILogger logger, Metrics metrics) + : base(peer, endPoint, dataHandler, config, socketInfo, time, logger, metrics) + { + _reliableBatch = new ArrayBatch(socketInfo.MaxReliableSize, logger, this, PacketType.Reliable, SendMode.Reliable); + _unreliableBatch = new ArrayBatch(socketInfo.MaxUnreliableSize, logger, this, PacketType.Unreliable, SendMode.Unreliable); + _ackSystem = new AckSystem(this, config, socketInfo.MaxUnreliableSize, time, bufferPool, logger, metrics); + } + + /// + /// single message, batched by AckSystem + /// + /// + public override void SendReliable(byte[] message, int offset, int length) + { + ThrowIfNotConnectedOrConnecting(); + + if (length + HEADER_SIZE > _socketInfo.MaxReliableSize) + { + throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_socketInfo.MaxReliableSize - HEADER_SIZE}"); + } + + _reliableBatch.AddMessage(message, offset, length); + _metrics?.OnSendMessageReliable(length); + } + + public override void SendUnreliable(byte[] message, int offset, int length) + { + ThrowIfNotConnectedOrConnecting(); + + if (length + HEADER_SIZE > _socketInfo.MaxUnreliableSize) + { + throw new ArgumentException($"Message is bigger than MTU, size:{length} but max message size is {_socketInfo.MaxUnreliableSize - HEADER_SIZE}"); + } + + _unreliableBatch.AddMessage(message, offset, length); + _metrics?.OnSendMessageUnreliable(length); + } + + /// + /// Use version for non-alloc + /// + public override INotifyToken SendNotify(byte[] packet, int offset, int length) + { + ThrowIfNotConnectedOrConnecting(); + var token = _ackSystem.SendNotify(packet, offset, length); + _metrics?.OnSendMessageNotify(length); + return token; + } + + /// + /// Use version for non-alloc + /// + public override void SendNotify(byte[] packet, int offset, int length, INotifyCallBack callBacks) + { + ThrowIfNotConnectedOrConnecting(); + _ackSystem.SendNotify(packet, offset, length, callBacks); + _metrics?.OnSendMessageNotify(length); + } + + internal override void ReceiveUnreliablePacket(Packet packet) + { + HandleReliableBatched(packet.Buffer.array, 1, packet.Length, PacketType.Unreliable); + } + + internal override void ReceiveReliablePacket(Packet packet) + { + HandleReliableBatched(packet.Buffer.array, 1, packet.Length, PacketType.Reliable); + } + + internal override void ReceiveReliableFragment(Packet packet) => throw new NotSupportedException(); + + internal override void ReceiveNotifyPacket(Packet packet) + { + var segment = _ackSystem.ReceiveNotify(packet.Buffer.array, packet.Length); + if (segment != default) + { + _metrics?.OnReceiveMessageNotify(packet.Length); + _dataHandler.ReceiveMessage(this, segment); + } + } + + internal override void ReceiveNotifyAck(Packet packet) + { + _ackSystem.ReceiveAck(packet.Buffer.array); + } + + public override void FlushBatch() + { + _ackSystem.Update(); + _reliableBatch.Flush(); + _unreliableBatch.Flush(); + } + + internal override bool IsValidSize(Packet packet) + { + const int minPacketSize = 1; + + var length = packet.Length; + if (length < minPacketSize) + return false; + + // Min size of message given to Mirage + const int minMessageSize = 2; + + const int minCommandSize = 2; + const int minUnreliableSize = 1 + minMessageSize; + + switch (packet.Type) + { + case PacketType.Command: + return length >= minCommandSize; + + case PacketType.Reliable: + case PacketType.Unreliable: + return length >= minUnreliableSize; + + case PacketType.Notify: + return length >= AckSystem.NOTIFY_HEADER_SIZE + minMessageSize; + case PacketType.Ack: + return length >= AckSystem.ACK_HEADER_SIZE; + case PacketType.ReliableFragment: + // not supported + return false; + + default: + case PacketType.KeepAlive: + return true; + } + } + } +} diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs.meta b/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs.meta new file mode 100644 index 00000000000..323b22f9264 --- /dev/null +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/PassthroughConnection.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: bc9c29bb7ffe23349b834ab6a92617c6 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirage/Runtime/SocketLayer/Connection/ReliableConnection.cs b/Assets/Mirage/Runtime/SocketLayer/Connection/ReliableConnection.cs index 0c5fb98a2e7..026089d6b36 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Connection/ReliableConnection.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Connection/ReliableConnection.cs @@ -6,23 +6,17 @@ namespace Mirage.SocketLayer /// /// Objects that represents a connection to/from a server/client. Holds state that is needed to update, send, and receive data /// - internal sealed class ReliableConnection : Connection, IRawConnection, IDisposable + internal sealed class ReliableConnection : Connection, IDisposable { private readonly AckSystem _ackSystem; private readonly Batch _unreliableBatch; - private readonly Pool _bufferPool; - internal ReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, int maxPacketSize, Time time, Pool bufferPool, ILogger logger, Metrics metrics) - : base(peer, endPoint, dataHandler, config, maxPacketSize, time, logger, metrics) + internal ReliableConnection(Peer peer, IEndPoint endPoint, IDataHandler dataHandler, Config config, SocketInfo socketInfo, Time time, Pool bufferPool, ILogger logger, Metrics metrics) + : base(peer, endPoint, dataHandler, config, socketInfo, time, logger, metrics) { - _bufferPool = bufferPool; - _unreliableBatch = new ArrayBatch(_maxPacketSize, SendBatchInternal, PacketType.Unreliable); - _ackSystem = new AckSystem(this, config, maxPacketSize, time, bufferPool, logger, metrics); - } - - private void SendBatchInternal(byte[] batch, int length) - { - _peer.Send(this, batch, length); + Debug.Assert(socketInfo.Reliability == SocketReliability.Unreliable); + _unreliableBatch = new ArrayBatch(socketInfo.MaxUnreliableSize, logger, this, PacketType.Unreliable, SendMode.Unreliable); + _ackSystem = new AckSystem(this, config, socketInfo.MaxUnreliableSize, time, bufferPool, logger, metrics); } public void Dispose() @@ -30,11 +24,6 @@ public void Dispose() _ackSystem.Dispose(); } - void IRawConnection.SendRaw(byte[] packet, int length) - { - _peer.Send(this, packet, length); - } - /// /// Use version for non-alloc /// @@ -71,9 +60,10 @@ public override void SendUnreliable(byte[] packet, int offset, int length) { ThrowIfNotConnectedOrConnecting(); - if (length + 1 > _maxPacketSize) + // todo allow message up to MaxUnreliableSize-1, by not batching + if (length + 1 + Batch.MESSAGE_LENGTH_SIZE > _socketInfo.MaxUnreliableSize) { - throw new ArgumentException($"Message is bigger than MTU, size:{length} but max Unreliable message size is {_maxPacketSize - 1}"); + throw new ArgumentException($"Message is bigger than MTU, size:{length} but max Unreliable message size is {_socketInfo.MaxUnreliableSize - (1 + Batch.MESSAGE_LENGTH_SIZE)}"); } _unreliableBatch.AddMessage(packet, offset, length); diff --git a/Assets/Mirage/Runtime/SocketLayer/ISocket.cs b/Assets/Mirage/Runtime/SocketLayer/ISocket.cs index b630a602cce..81a0ced7ef0 100644 --- a/Assets/Mirage/Runtime/SocketLayer/ISocket.cs +++ b/Assets/Mirage/Runtime/SocketLayer/ISocket.cs @@ -37,7 +37,7 @@ public interface ISocket /// and make sure not to return above that size /// /// - /// buffer to write recevived packet into + /// buffer to write receive packet into /// where packet came from /// length of packet, should not be above length int Receive(byte[] buffer, out IEndPoint endPoint); @@ -49,7 +49,13 @@ public interface ISocket /// where packet is being sent to /// buffer that contains the packet, starting at index 0 /// length of the packet - void Send(IEndPoint endPoint, byte[] packet, int length); + void Send(IEndPoint endPoint, byte[] packet, int length, SendMode sendMode); + } + + public enum SendMode + { + Unreliable = 0, + Reliable = 1, } /// diff --git a/Assets/Mirage/Runtime/SocketLayer/Peer.cs b/Assets/Mirage/Runtime/SocketLayer/Peer.cs index 4c228294308..ef884e45e15 100644 --- a/Assets/Mirage/Runtime/SocketLayer/Peer.cs +++ b/Assets/Mirage/Runtime/SocketLayer/Peer.cs @@ -42,7 +42,7 @@ public sealed class Peer : IPeer private readonly ISocket _socket; private readonly IDataHandler _dataHandler; private readonly Config _config; - private readonly int _maxPacketSize; + private readonly SocketInfo _socketInfo; private readonly Time _time; private readonly ConnectKeyValidator _connectKeyValidator; private readonly Pool _bufferPool; @@ -61,14 +61,15 @@ public sealed class Peer : IPeer private bool _active; public PoolMetrics PoolMetrics => _bufferPool.Metrics; - public Peer(ISocket socket, int maxPacketSize, IDataHandler dataHandler, Config config = null, ILogger logger = null, Metrics metrics = null) + public Peer(ISocket socket, SocketInfo socketInfo, IDataHandler dataHandler, Config config = null, ILogger logger = null, Metrics metrics = null) { _logger = logger; _metrics = metrics; _config = config ?? new Config(); - _maxPacketSize = maxPacketSize; - if (maxPacketSize < AckSystem.MIN_RELIABLE_HEADER_SIZE + 1) - throw new ArgumentException($"Max packet size too small for AckSystem header", nameof(maxPacketSize)); + _socketInfo = socketInfo; + + if (_socketInfo.Reliability == 0) + throw new ArgumentNullException(nameof(socketInfo), "socketInfo was default and had no values"); _socket = socket ?? throw new ArgumentNullException(nameof(socket)); _dataHandler = dataHandler ?? throw new ArgumentNullException(nameof(dataHandler)); @@ -76,7 +77,7 @@ public Peer(ISocket socket, int maxPacketSize, IDataHandler dataHandler, Config _connectKeyValidator = new ConnectKeyValidator(_config.key); - _bufferPool = new Pool(ByteBuffer.CreateNew, maxPacketSize, _config.BufferPoolStartSize, _config.BufferPoolMaxSize, _logger); + _bufferPool = new Pool(ByteBuffer.CreateNew, socketInfo.MaxSize, _config.BufferPoolStartSize, _config.BufferPoolMaxSize, _logger); Application.quitting += Application_quitting; } @@ -132,13 +133,13 @@ public void Close() _socket.Close(); } - internal void Send(Connection connection, byte[] data, int length) + internal void Send(Connection connection, byte[] data, int length, SendMode mode) { // connecting connections can send connect messages so is allowed // todo check connected before message are sent from high level _logger?.Assert(connection.State == ConnectionState.Connected || connection.State == ConnectionState.Connecting || connection.State == ConnectionState.Disconnected, connection.State); - _socket.Send(connection.EndPoint, data, length); + _socket.Send(connection.EndPoint, data, length, mode); _metrics?.OnSend(length); connection.SetSendTime(); @@ -161,7 +162,7 @@ internal void SendCommandUnconnected(IEndPoint endPoint, Commands command, byte? { var length = CreateCommandPacket(buffer, command, extra); - _socket.Send(endPoint, buffer.array, length); + _socket.Send(endPoint, buffer.array, length, SendMode.Reliable); _metrics?.OnSendUnconnected(length); if (_logger.Enabled(LogType.Log)) { @@ -176,7 +177,7 @@ internal void SendConnectRequest(Connection connection) { var length = CreateCommandPacket(buffer, Commands.ConnectRequest, null); _connectKeyValidator.CopyTo(buffer.array); - Send(connection, buffer.array, length + _connectKeyValidator.KeyLength); + Send(connection, buffer.array, length + _connectKeyValidator.KeyLength, SendMode.Reliable); } } @@ -185,7 +186,7 @@ internal void SendCommand(Connection connection, Commands command, byte? extra = using (var buffer = _bufferPool.Take()) { var length = CreateCommandPacket(buffer, command, extra); - Send(connection, buffer.array, length); + Send(connection, buffer.array, length, SendMode.Reliable); } } @@ -217,7 +218,7 @@ internal void SendKeepAlive(Connection connection) using (var buffer = _bufferPool.Take()) { buffer.array[0] = (byte)PacketType.KeepAlive; - Send(connection, buffer.array, 1); + Send(connection, buffer.array, 1, SendMode.Unreliable); } } @@ -247,8 +248,8 @@ private void ReceiveLoop() var length = _socket.Receive(buffer.array, out var receiveEndPoint); // this should never happen. buffer size is only MTU, if socket returns higher length then it has a bug. - if (length > _maxPacketSize) - throw new IndexOutOfRangeException($"Socket returned length above MTU. MaxPacketSize:{_maxPacketSize} length:{length}"); + if (length > _socketInfo.MaxSize) + throw new IndexOutOfRangeException($"Socket returned length above MTU. MaxPacketSize:{_socketInfo.MaxSize} length:{length}"); var packet = new Packet(buffer, length); @@ -368,10 +369,18 @@ private void HandleCommand(Connection connection, Packet packet) private void HandleNewConnection(IEndPoint endPoint, Packet packet) { + // first check if new packet is valid // if invalid, then reject without reason - if (!Validate(packet)) { return; } - + // key could be anything, so any message over 2 could be key. + var minLength = 2; + if (packet.Length < minLength) + return; + if (packet.Type != PacketType.Command) + return; + if (packet.Command != Commands.ConnectRequest) + return; + // then process other reject reasons if (AtMaxConnections()) { RejectConnectionWithReason(endPoint, RejectReason.ServerFull); @@ -388,23 +397,6 @@ private void HandleNewConnection(IEndPoint endPoint, Packet packet) AcceptNewConnection(endPoint); } } - - private bool Validate(Packet packet) - { - // key could be anything, so any message over 2 could be key. - var minLength = 2; - if (packet.Length < minLength) - return false; - - if (packet.Type != PacketType.Command) - return false; - - if (packet.Command != Commands.ConnectRequest) - return false; - - return true; - } - private bool AtMaxConnections() { return _connections.Count >= _config.MaxConnections; @@ -425,16 +417,21 @@ private Connection CreateNewConnection(IEndPoint newEndPoint) var endPoint = newEndPoint?.CreateCopy(); Connection connection; - if (_config.DisableReliableLayer) - { - connection = new NoReliableConnection(this, endPoint, _dataHandler, _config, _maxPacketSize, _time, _logger, _metrics); - } - else + + switch (_socketInfo.Reliability) { - connection = new ReliableConnection(this, endPoint, _dataHandler, _config, _maxPacketSize, _time, _bufferPool, _logger, _metrics); + default:// note: default will never happen because it will throw in constructor + case SocketReliability.Unreliable: + connection = new ReliableConnection(this, endPoint, _dataHandler, _config, _socketInfo, _time, _bufferPool, _logger, _metrics); + break; + case SocketReliability.Reliable: + connection = new NoReliableConnection(this, endPoint, _dataHandler, _config, _socketInfo, _time, _logger, _metrics); + break; + case SocketReliability.Both: + connection = new PassthroughConnection(this, endPoint, _dataHandler, _config, _socketInfo, _time, _bufferPool, _logger, _metrics); + break; } - connection.SetReceiveTime(); _connections.Add(endPoint, connection); return connection; diff --git a/Assets/Mirage/Runtime/SocketLayer/SocketFactory.cs b/Assets/Mirage/Runtime/SocketLayer/SocketFactory.cs index 580efe01764..2448dbff1d3 100644 --- a/Assets/Mirage/Runtime/SocketLayer/SocketFactory.cs +++ b/Assets/Mirage/Runtime/SocketLayer/SocketFactory.cs @@ -32,9 +32,9 @@ public interface IHasPort [HelpURL("https://miragenet.github.io/Mirage/docs/general/sockets#changing-a-socket")] public abstract class SocketFactory : MonoBehaviour { - /// Max size for packets sent to or received from Socket + /// Gets info about the socket being created /// Called once when Sockets are created - public abstract int MaxPacketSize { get; } + public abstract SocketInfo SocketInfo { get; } /// Creates a to be used by on the server /// Throw when Server is not supported on current platform diff --git a/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs b/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs new file mode 100644 index 00000000000..b4d8977033b --- /dev/null +++ b/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs @@ -0,0 +1,97 @@ +using System; + +namespace Mirage.SocketLayer +{ + public enum SocketReliability + { + // note: 0 is unset, it will be used to check if SocketInfo is default or not + + /// + /// all packets are unreliable, eg udp + /// + Unreliable = 1, + + /// + /// all packets are reliable, eg tcp or webSockets + /// + Reliable = 2, + + /// + /// if socket supports both reliable and unreliable, eg steam or epic relay + /// + Both = 3, + } + + public readonly struct SocketInfo + { + /// + /// How the socket handles reliability + /// + public readonly SocketReliability Reliability; + + /// + /// If socket supports Reliable, what is the max packet size. This should include max Fragmentation size if socket handles that + /// + public readonly int MaxReliableSize; + + /// + /// If socket supports Unreliable, what is the max packet size + /// + public readonly int MaxUnreliableSize; + + /// + /// Will the Socket handle Fragmentation for Reliable messages + /// if false, Mirage will fragment message before sending them to socket + /// + public readonly bool ReliableFragmentation; + + /// + /// Max size required by either reliable or unreliable + /// + public readonly int MaxSize; + + public SocketInfo(SocketReliability reliability, int maxReliableSize, int maxUnreliableSize, bool reliableFragmentation) + { + Reliability = reliability; + MaxReliableSize = maxReliableSize; + MaxUnreliableSize = maxUnreliableSize; + ReliableFragmentation = reliableFragmentation; + MaxSize = Math.Max(MaxReliableSize, MaxUnreliableSize); + + // this smallest size that a socket must support to work with Mirage + // note: this number is arbitrary, but is a reasonable size, any smaller and too many packets will need to be fragmented and sent + const int minMessageSize = 100; + switch (reliability) + { + case SocketReliability.Unreliable: + // Mirage will handle reliability, so max size must be big enough so that header can fit + if (MaxUnreliableSize < AckSystem.MIN_RELIABLE_HEADER_SIZE + minMessageSize) + throw new ArgumentException($"Max unreliable size too small for AckSystem header", nameof(maxUnreliableSize)); + break; + case SocketReliability.Reliable: + // Mirage will just batch message and send them to socket + if (MaxReliableSize < Batch.MESSAGE_LENGTH_SIZE + minMessageSize) + throw new ArgumentException($"Max reliable size too small for Batch header", nameof(maxUnreliableSize)); + break; + + case SocketReliability.Both: + if (MaxUnreliableSize < AckSystem.NOTIFY_HEADER_SIZE + minMessageSize) + throw new ArgumentException($"Max unreliable size too small for Notify header", nameof(maxUnreliableSize)); + + if (MaxReliableSize < Batch.MESSAGE_LENGTH_SIZE + minMessageSize) + throw new ArgumentException($"Max reliable size too small for AckSystem header", nameof(maxUnreliableSize)); + break; + } + + + if (MaxReliableSize > ushort.MaxValue) + { + throw new ArgumentException($"Max package size can not bigger than {ushort.MaxValue}. NoReliableConnection uses 2 bytes for message length, maxPacketSize over that value will mean that message will be incorrectly batched."); + } + if (MaxUnreliableSize > ushort.MaxValue) + { + throw new ArgumentException($"Max package size can not bigger than {ushort.MaxValue}. NoReliableConnection uses 2 bytes for message length, maxPacketSize over that value will mean that message will be incorrectly batched."); + } + } + } +} diff --git a/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs.meta b/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs.meta new file mode 100644 index 00000000000..0661937f813 --- /dev/null +++ b/Assets/Mirage/Runtime/SocketLayer/SocketInfo.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 66d3e84fcc908234fb336943fdea88ab +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Assets/Mirage/Runtime/Sockets/Udp/NanoSocket.cs b/Assets/Mirage/Runtime/Sockets/Udp/NanoSocket.cs index c77e9830878..2d6f7900078 100644 --- a/Assets/Mirage/Runtime/Sockets/Udp/NanoSocket.cs +++ b/Assets/Mirage/Runtime/Sockets/Udp/NanoSocket.cs @@ -6,7 +6,6 @@ namespace Mirage.Sockets.Udp { - public sealed class NanoSocket : ISocket, IDisposable { public static bool Supported => true; @@ -82,7 +81,7 @@ public int Receive(byte[] buffer, out IEndPoint endPoint) return count; } - public void Send(IEndPoint endPoint, byte[] packet, int length) + public void Send(IEndPoint endPoint, byte[] packet, int length, SendMode _) { var nanoEndPoint = (NanoEndPoint)endPoint; UDP.Send(socket, ref nanoEndPoint.address, packet, length); diff --git a/Assets/Mirage/Runtime/Sockets/Udp/UdpSocket.cs b/Assets/Mirage/Runtime/Sockets/Udp/UdpSocket.cs index 7bff84f1fcb..c6754747b54 100644 --- a/Assets/Mirage/Runtime/Sockets/Udp/UdpSocket.cs +++ b/Assets/Mirage/Runtime/Sockets/Udp/UdpSocket.cs @@ -94,7 +94,7 @@ public int Receive(byte[] buffer, out IEndPoint endPoint) return c; } - public void Send(IEndPoint endPoint, byte[] packet, int length) + public void Send(IEndPoint endPoint, byte[] packet, int length, SendMode _) { var netEndPoint = ((EndPointWrapper)endPoint).inner; socket.SendTo(packet, length, SocketFlags.None, netEndPoint); diff --git a/Assets/Mirage/Runtime/Sockets/Udp/UdpSocketFactory.cs b/Assets/Mirage/Runtime/Sockets/Udp/UdpSocketFactory.cs index 3a9121d8dd1..0559b13e2ef 100644 --- a/Assets/Mirage/Runtime/Sockets/Udp/UdpSocketFactory.cs +++ b/Assets/Mirage/Runtime/Sockets/Udp/UdpSocketFactory.cs @@ -26,8 +26,7 @@ public sealed class UdpSocketFactory : SocketFactory, IHasAddress, IHasPort [Header("NanoSocket-specific Options")] public int BufferSize = 256 * 1024; - - public override int MaxPacketSize => UdpMTU.MaxPacketSize; + public override SocketInfo SocketInfo => new SocketInfo(SocketReliability.Unreliable, 0, UdpMTU.MaxPacketSize, false); // Determines if we can use NanoSockets for socket-level IO. This will be true if either: // - We *want* to use native library explicitly. diff --git a/Assets/Tests/SocketLayer/AckSystem/NoReliableConnectionTest.cs b/Assets/Tests/SocketLayer/AckSystem/NoReliableConnectionTest.cs index 9bb8ecdacfa..2adf7bffb1c 100644 --- a/Assets/Tests/SocketLayer/AckSystem/NoReliableConnectionTest.cs +++ b/Assets/Tests/SocketLayer/AckSystem/NoReliableConnectionTest.cs @@ -5,30 +5,27 @@ using NSubstitute; using NUnit.Framework; -namespace Mirage.SocketLayer.Tests.AckSystemTests +namespace Mirage.SocketLayer.Tests { [Category("SocketLayer")] - public class NoReliableConnectionTest + public abstract class ConnectionTestBase { - private const int MAX_PACKET_SIZE = 100; - - private IConnection _connection; - private byte[] _buffer; - private Config _config; - private PeerInstance _peerInstance; - private Pool _bufferPool; - private readonly Random rand = new Random(); - private byte[] _sentArray; - - private ISocket Socket => _peerInstance.socket; + protected IConnection _connection; + protected byte[] _buffer; + protected Config _config; + protected PeerInstance _peerInstance; + protected Pool _bufferPool; + protected readonly Random rand = new Random(); + protected List _sentArrays = new List(); + + protected ISocket Socket => _peerInstance.socket; + protected abstract Config CreateConfig(); + protected virtual int MAX_PACKET_SIZE => 100; [SetUp] public void Setup() { - _config = new Config - { - DisableReliableLayer = true, - }; + _config = CreateConfig(); _peerInstance = new PeerInstance(_config, maxPacketSize: MAX_PACKET_SIZE); _bufferPool = new Pool(ByteBuffer.CreateNew, MAX_PACKET_SIZE, 0, 100); @@ -41,16 +38,122 @@ public void Setup() } // clear calls, Connect will have sent one + _sentArrays.Clear(); Socket.ClearReceivedCalls(); Socket.When(x => x.Send(Arg.Any(), Arg.Any(), Arg.Any())) .Do(x => { - var arg = (byte[])x.Args()[1]; + var packet = (byte[])x.Args()[1]; + var length = (int)x.Args()[2]; // create copy - _sentArray = arg.ToArray(); + _sentArrays.Add(packet.Take(length).ToArray()); }); } + + protected void AssertSentPacket(PacketType type, IEnumerable messageLengths) + { + var totalLength = 1 + (2 * messageLengths.Count()) + messageLengths.Sum(); + + // only 1 at any length + Socket.Received(1).Send(Arg.Any(), Arg.Any(), Arg.Any()); + // but also check we received length + Socket.Received(1).Send(Arg.Any(), Arg.Any(), totalLength); + + // check packet was correct + CheckMessage(type, 0, 1, messageLengths); + + // clear calls after, so we are ready to process next message + Socket.ClearReceivedCalls(); + _sentArrays.Clear(); + } + + protected void CheckMessage(PacketType type, int sentIndex, int sendCount, IEnumerable messageLengths, int skipHeader = 0) + { + Assert.That(_sentArrays.Count, Is.EqualTo(sendCount)); + var packet = _sentArrays[sentIndex]; + if (packet[0] != (byte)type) + Assert.Fail($"First byte should be the packet type, {type}, it was {(PacketType)packet[0]} instead"); + + var offset = 1; + foreach (var length in messageLengths) + { + if (skipHeader == 0) + { + var ln = ByteUtils.ReadUShort(packet, ref offset); + if (ln != length) + Assert.Fail($"Length at offset {offset - 2} was incorrect.\n Expected:{length}\n But war:{ln}"); + + + for (var i = 0; i < length; i++) + { + if (packet[offset + i] != _buffer[i]) + Assert.Fail($"Value at offset {offset + i} was incorrect.\n Expected:{_buffer[i]}\n But war:{packet[offset + i]}"); + + } + offset += length; + } + else + { + offset += skipHeader; + for (var i = 0; i < length; i++) + { + if (packet[offset + i] != _buffer[i]) + Assert.Fail($"Value at offset {offset + i} was incorrect.\n Expected:{_buffer[i]}\n But war:{packet[offset + i]}"); + + } + offset += length; + } + } + + Assert.That(offset, Is.EqualTo(packet.Length)); + } + + protected void SendIntoBatch(int length, bool reliable, ref int total, List currentBatch) + { + // will write length+2 + var newTotal = total + 2 + length; + if (newTotal > MAX_PACKET_SIZE) + { + Send(reliable, _buffer, length); + // was over max, so should have sent + AssertSentPacket(reliable ? PacketType.Reliable : PacketType.Unreliable, currentBatch); + + currentBatch.Clear(); + // new batch + total = 1 + 2 + length; + } + else + { + Send(reliable, _buffer, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + total = newTotal; + } + currentBatch.Add(length); + } + + protected void Send(bool reliable, byte[] buffer, int length) + { + if (reliable) + _connection.SendReliable(buffer, 0, length); + else + _connection.SendUnreliable(buffer, 0, length); + } + } + + [Category("SocketLayer")] + public class NoReliableConnectionTest : ConnectionTestBase + { + private new Connection _connection => (Connection)base._connection; + + protected override Config CreateConfig() + { + return new Config + { + DisableReliableLayer = true, + }; + } + [Test] public void IsNoReliableConnection() { @@ -72,55 +175,74 @@ public void ThrowsIfTooBig() Assert.That(exception, Has.Message.EqualTo(expected.Message)); } - private void AssertSentPacket(IEnumerable messageLengths) + [Test] + public void MessageAreBatched() { - var totalLength = 1 + (2 * messageLengths.Count()) + messageLengths.Sum(); + // max is 100 - // only 1 at any length - Socket.Received(1).Send(Arg.Any(), Arg.Any(), Arg.Any()); - // but also check we received length - Socket.Received(1).Send(Arg.Any(), Arg.Any(), totalLength); + var lessThanBatchLengths = new int[] + { + 20, 40, 30 + }; + var overBatch = 11; - // check packet was correct - CheckMessage(_sentArray, messageLengths); + foreach (var length in lessThanBatchLengths) + { + _connection.SendReliable(_buffer, 0, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } - // clear calls after, so we are ready to process next message - Socket.ClearReceivedCalls(); + // should be 97 in buffer now => 1+(length+2)*3 + _connection.SendReliable(_buffer, 0, overBatch); + AssertSentPacket(PacketType.Reliable, lessThanBatchLengths); } - private void CheckMessage(byte[] packet, IEnumerable messageLengths) + [Test] + [Repeat(100)] + public void MessageAreBatched_Repeat() { - if (packet[0] != (byte)PacketType.Reliable) - Assert.Fail($"First byte was not Reliable, it was {packet[0]} instead"); - - var offset = 1; - foreach (var length in messageLengths) + const int messageCount = 10; + var lengths = new int[messageCount]; + for (var i = 0; i < messageCount; i++) { - var ln = ByteUtils.ReadUShort(packet, ref offset); - if (ln != length) - Assert.Fail($"Length at offset {offset - 2} was incorrect.\n Expected:{length}\n But war:{ln}"); - + lengths[i] = rand.Next(10, MAX_PACKET_SIZE - 3); + } + var currentBatch = new List(); - for (var i = 0; i < length; i++) + var total = 1; + foreach (var length in lengths) + { + // will write length+2 + var newTotal = total + 2 + length; + if (newTotal > MAX_PACKET_SIZE) { - if (packet[offset + i] != _buffer[i]) - Assert.Fail($"Value at offset {offset + i} was incorrect.\n Expected:{_buffer[i]}\n But war:{packet[offset + i]}"); + _connection.SendReliable(_buffer, 0, length); + // was over max, so should have sent + AssertSentPacket(PacketType.Reliable, currentBatch); + currentBatch.Clear(); + // new batch + total = 1 + 2 + length; + } + else + { + _connection.SendReliable(_buffer, 0, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + total = newTotal; } - offset += length; + currentBatch.Add(length); } } [Test] - public void MessageAreBatched() + public void FlushSendsMessageInBatch() { // max is 100 var lessThanBatchLengths = new int[] { - 20, 40, 30 + 20, 40 }; - var overBatch = 11; foreach (var length in lessThanBatchLengths) { @@ -128,9 +250,167 @@ public void MessageAreBatched() Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); } - // should be 97 in buffer now => 1+(length+2)*3 - _connection.SendReliable(_buffer, 0, overBatch); - AssertSentPacket(lessThanBatchLengths); + _connection.FlushBatch(); + AssertSentPacket(PacketType.Reliable, lessThanBatchLengths); + } + + [Test] + public void FlushDoesNotSendEmptyMessage() + { + _connection.FlushBatch(); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + _connection.FlushBatch(); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + + + [Test] + public void UnbatchesMessageOnReceive() + { + var receive = _bufferPool.Take(); + receive.array[0] = (byte)PacketType.Reliable; + var offset = 1; + AddMessage(receive.array, ref offset, 10); + AddMessage(receive.array, ref offset, 30); + AddMessage(receive.array, ref offset, 20); + + var segments = new List>(); + _peerInstance.dataHandler + .When(x => x.ReceiveMessage(_connection, Arg.Any>())) + .Do(x => segments.Add(x.ArgAt>(1))); + ((NoReliableConnection)_connection).ReceiveReliablePacket(new Packet(receive, offset)); + _peerInstance.dataHandler.Received(3).ReceiveMessage(_connection, Arg.Any>()); + + + Assert.That(segments[0].Count, Is.EqualTo(10)); + Assert.That(segments[1].Count, Is.EqualTo(30)); + Assert.That(segments[2].Count, Is.EqualTo(20)); + Assert.That(segments[0].SequenceEqual(new ArraySegment(_buffer, 0, 10))); + Assert.That(segments[1].SequenceEqual(new ArraySegment(_buffer, 0, 30))); + Assert.That(segments[2].SequenceEqual(new ArraySegment(_buffer, 0, 20))); + } + + private void AddMessage(byte[] receive, ref int offset, int size) + { + ByteUtils.WriteUShort(receive, ref offset, (ushort)size); + Buffer.BlockCopy(_buffer, 0, receive, offset, size); + offset += size; + } + + [Test] + public void SendingToUnreliableUsesReliable() + { + var counts = new List() { 10, 20 }; + _connection.SendUnreliable(_buffer, 0, counts[0]); + _connection.SendUnreliable(_buffer, 0, counts[1]); + _connection.FlushBatch(); + + AssertSentPacket(PacketType.Reliable, counts); + } + + [Test] + public void SendingToNotifyUsesReliable() + { + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0]); + _connection.SendNotify(_buffer, 0, counts[1]); + _connection.FlushBatch(); + + AssertSentPacket(PacketType.Reliable, counts); + } + [Test] + public void SendingToNotifyTokenUsesReliable() + { + var token = Substitute.For(); + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0], token); + _connection.SendNotify(_buffer, 0, counts[1], token); + _connection.FlushBatch(); + + AssertSentPacket(PacketType.Reliable, counts); + } + + [Test] + public void NotifyOnDeliveredInvoke() + { + var counts = new List() { 10, 20 }; + var token = _connection.SendNotify(_buffer, 0, counts[0]); + Assert.That(token, Is.TypeOf()); + + var action = Substitute.For(); + token.Delivered += action; + action.Received(1).Invoke(); + } + + [Test] + public void NotifyTokenOnDeliveredInvoke() + { + var token = Substitute.For(); + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0], token); + token.Received(1).OnDelivered(); + } + } + + + + [Category("SocketLayer")] + public class LargeMessageOnTest : ConnectionTestBase + { + private new Connection _connection => (Connection)base._connection; + protected override int MAX_PACKET_SIZE => ushort.MaxValue + 5000; + + protected override Config CreateConfig() + { + return new Config + { + DisableReliableLayer = true, + }; + } + + [Test] + public void ThrowsIfTooBig() + { + // 3 byte header, so max size is over max + var bigBuffer = new byte[MAX_PACKET_SIZE - 2]; + + var exception = Assert.Throws(() => + { + _connection.SendReliable(bigBuffer); + }); + + var expected = new ArgumentException($"Message is bigger than MTU, size:{bigBuffer.Length} but max message size is {MAX_PACKET_SIZE - 3}"); + Assert.That(exception, Has.Message.EqualTo(expected.Message)); + } + + [Test] + public void MessageOverUshortAreNotBatched() + { + var length = ushort.MaxValue + 10; + + _connection.SendReliable(_buffer, 0, length); + + var totalLength = 1 + 2 + length; + Socket.Received(1).Send(Arg.Any(), Arg.Any(), totalLength); + + // check packet was correct + Assert.That(_sentArrays.Count, Is.EqualTo(1)); + var packet = _sentArrays[0]; + Assert.That(packet.Length, Is.EqualTo(totalLength)); + if (packet[0] != (byte)PacketType.Reliable) + Assert.Fail($"First byte should be the packet type, {PacketType.Reliable}, it was {(PacketType)packet[0]} instead"); + + var offset = 1; + var ln = ByteUtils.ReadUShort(packet, ref offset); + Assert.That(ln, Is.EqualTo(0), "non-batch message should have length zero"); + for (var i = 0; i < length; i++) + { + if (packet[offset + i] != _buffer[i]) + Assert.Fail($"Value at offset {offset + i} was incorrect.\n Expected:{_buffer[i]}\n But war:{packet[offset + i]}"); + } + offset += length; + + Assert.That(offset, Is.EqualTo(packet.Length)); } [Test] @@ -154,7 +434,7 @@ public void MessageAreBatched_Repeat() { _connection.SendReliable(_buffer, 0, length); // was over max, so should have sent - AssertSentPacket(currentBatch); + AssertSentPacket(PacketType.Reliable, currentBatch); currentBatch.Clear(); // new batch @@ -187,7 +467,7 @@ public void FlushSendsMessageInBatch() } _connection.FlushBatch(); - AssertSentPacket(lessThanBatchLengths); + AssertSentPacket(PacketType.Reliable, lessThanBatchLengths); } [Test] @@ -241,7 +521,7 @@ public void SendingToUnreliableUsesReliable() _connection.SendUnreliable(_buffer, 0, counts[1]); _connection.FlushBatch(); - AssertSentPacket(counts); + AssertSentPacket(PacketType.Reliable, counts); } [Test] @@ -252,7 +532,7 @@ public void SendingToNotifyUsesReliable() _connection.SendNotify(_buffer, 0, counts[1]); _connection.FlushBatch(); - AssertSentPacket(counts); + AssertSentPacket(PacketType.Reliable, counts); } [Test] public void SendingToNotifyTokenUsesReliable() @@ -263,7 +543,7 @@ public void SendingToNotifyTokenUsesReliable() _connection.SendNotify(_buffer, 0, counts[1], token); _connection.FlushBatch(); - AssertSentPacket(counts); + AssertSentPacket(PacketType.Reliable, counts); } [Test] @@ -272,6 +552,10 @@ public void NotifyOnDeliveredInvoke() var counts = new List() { 10, 20 }; var token = _connection.SendNotify(_buffer, 0, counts[0]); Assert.That(token, Is.TypeOf()); + + var action = Substitute.For(); + token.Delivered += action; + action.Received(1).Invoke(); } [Test] diff --git a/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs b/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs new file mode 100644 index 00000000000..aa30762cbc6 --- /dev/null +++ b/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs @@ -0,0 +1,290 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using NSubstitute; +using NUnit.Framework; + +namespace Mirage.SocketLayer.Tests +{ + [Category("SocketLayer")] + public class PassthroughConnectionTest : ConnectionTestBase + { + private new Connection _connection => (Connection)base._connection; + + protected override Config CreateConfig() + { + return new Config + { + PassthroughReliableLayer = true, + }; + } + + [Test] + public void IsNoReliableConnection() + { + Assert.That(_connection, Is.TypeOf()); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void ThrowsIfTooBig(bool reliable) + { + // 3 byte header, so max size is over max + var bigBuffer = new byte[MAX_PACKET_SIZE - 2]; + + var exception = Assert.Throws(() => + { + Send(reliable, bigBuffer, bigBuffer.Length); + }); + + var expected = new ArgumentException($"Message is bigger than MTU, size:{bigBuffer.Length} but max message size is {MAX_PACKET_SIZE - 3}"); + Assert.That(exception, Has.Message.EqualTo(expected.Message)); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void MessageAreBatched(bool reliable) + { + // max is 100 + + var lessThanBatchLengths = new int[] + { + 20, 40, 30 + }; + var overBatch = 11; + + foreach (var length in lessThanBatchLengths) + { + Send(reliable, _buffer, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + + // should be 97 in buffer now => 1+(length+2)*3 + Send(reliable, _buffer, overBatch); + AssertSentPacket(reliable ? PacketType.Reliable : PacketType.Unreliable, lessThanBatchLengths); + } + + [Test] + [Repeat(100)] + [TestCase(10)] + [TestCase(100)] + public void MessageAreBatched_Repeat(int messageCount) + { + var lengths = new int[messageCount]; + for (var i = 0; i < messageCount; i++) + lengths[i] = rand.Next(10, MAX_PACKET_SIZE - 3); + + var currentBatch_reliable = new List(); + var currentBatch_unreliable = new List(); + var total_reliable = 1; + var total_unreliable = 1; + foreach (var length in lengths) + { + var reliable = rand.Next(0, 1) == 1; + if (reliable) + SendIntoBatch(length, true, ref total_reliable, currentBatch_reliable); + else + SendIntoBatch(length, false, ref total_unreliable, currentBatch_unreliable); + } + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void FlushSendsMessageInBatch(bool reliable) + { + // max is 100 + + var lessThanBatchLengths = new int[] + { + 20, 40 + }; + + foreach (var length in lessThanBatchLengths) + { + Send(reliable, _buffer, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + + _connection.FlushBatch(); + AssertSentPacket(reliable ? PacketType.Reliable : PacketType.Unreliable, lessThanBatchLengths); + } + + [Test] + public void FlushSendsMessageInBatch_BothTypes() + { + // max is 100 + + var lessThanBatchLengths_reliable = new int[] + { + 20, 40 + }; + var lessThanBatchLengths_unreliable = new int[] + { + 15, 35, 20 + }; + + foreach (var length in lessThanBatchLengths_reliable) + { + _connection.SendReliable(_buffer, 0, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + foreach (var length in lessThanBatchLengths_unreliable) + { + _connection.SendUnreliable(_buffer, 0, length); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + + _connection.FlushBatch(); + + var totalLength_reliable = 1 + (2 * lessThanBatchLengths_reliable.Count()) + lessThanBatchLengths_reliable.Sum(); + var totalLength_unreliable = 1 + (2 * lessThanBatchLengths_unreliable.Count()) + lessThanBatchLengths_unreliable.Sum(); + + // only 2 at any length + Socket.Received(2).Send(Arg.Any(), Arg.Any(), Arg.Any()); + // but also check we received length + Socket.Received(1).Send(Arg.Any(), Arg.Any(), totalLength_reliable); + Socket.Received(1).Send(Arg.Any(), Arg.Any(), totalLength_unreliable); + + // check packet was correct + CheckMessage(PacketType.Reliable, 0, 2, lessThanBatchLengths_reliable); + CheckMessage(PacketType.Unreliable, 1, 2, lessThanBatchLengths_unreliable); + + // clear calls after, so we are ready to process next message + Socket.ClearReceivedCalls(); + _sentArrays.Clear(); + } + + [Test] + public void FlushDoesNotSendEmptyMessage() + { + _connection.FlushBatch(); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + _connection.FlushBatch(); + Socket.DidNotReceive().Send(Arg.Any(), Arg.Any(), Arg.Any()); + } + + [Test] + [TestCase(true)] + [TestCase(false)] + public void UnbatchesMessageOnReceive(bool reliable) + { + var receive = _bufferPool.Take(); + receive.array[0] = (byte)(reliable ? PacketType.Reliable : PacketType.Unreliable); + var offset = 1; + AddMessage(receive.array, ref offset, 10); + AddMessage(receive.array, ref offset, 30); + AddMessage(receive.array, ref offset, 20); + + var segments = new List>(); + _peerInstance.dataHandler + .When(x => x.ReceiveMessage(_connection, Arg.Any>())) + .Do(x => segments.Add(x.ArgAt>(1))); + if (reliable) + _connection.ReceiveReliablePacket(new Packet(receive, offset)); + else + _connection.ReceiveUnreliablePacket(new Packet(receive, offset)); + _peerInstance.dataHandler.Received(3).ReceiveMessage(_connection, Arg.Any>()); + + + Assert.That(segments[0].Count, Is.EqualTo(10)); + Assert.That(segments[1].Count, Is.EqualTo(30)); + Assert.That(segments[2].Count, Is.EqualTo(20)); + Assert.That(segments[0].SequenceEqual(new ArraySegment(_buffer, 0, 10))); + Assert.That(segments[1].SequenceEqual(new ArraySegment(_buffer, 0, 30))); + Assert.That(segments[2].SequenceEqual(new ArraySegment(_buffer, 0, 20))); + } + + private void AddMessage(byte[] receive, ref int offset, int size) + { + ByteUtils.WriteUShort(receive, ref offset, (ushort)size); + Buffer.BlockCopy(_buffer, 0, receive, offset, size); + offset += size; + } + + [Test] + public void SendingToUnreliableUsesUnreliable() + { + var counts = new List() { 10, 20 }; + _connection.SendUnreliable(_buffer, 0, counts[0]); + _connection.SendUnreliable(_buffer, 0, counts[1]); + _connection.FlushBatch(); + + AssertSentPacket(PacketType.Unreliable, counts); + } + + [Test] + public void SendingToNotifyUsesUnreliable() + { + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0]); + _connection.SendNotify(_buffer, 0, counts[1]); + _connection.FlushBatch(); + + // only 1 at any length + Socket.Received(2).Send(Arg.Any(), Arg.Any(), Arg.Any()); + // but also check we received length + Socket.Received(1).Send(Arg.Any(), Arg.Any(), AckSystem.NOTIFY_HEADER_SIZE + counts[0]); + Socket.Received(1).Send(Arg.Any(), Arg.Any(), AckSystem.NOTIFY_HEADER_SIZE + counts[1]); + + // check packet was correct + CheckMessage(PacketType.Notify, 0, 2, counts.Take(1), AckSystem.NOTIFY_HEADER_SIZE - 1); + CheckMessage(PacketType.Notify, 1, 2, counts.Skip(1).Take(1), AckSystem.NOTIFY_HEADER_SIZE - 1); + + // clear calls after, so we are ready to process next message + Socket.ClearReceivedCalls(); + _sentArrays.Clear(); + } + [Test] + public void SendingToNotifyTokenUsesUnreliable() + { + var token = Substitute.For(); + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0], token); + _connection.SendNotify(_buffer, 0, counts[1], token); + _connection.FlushBatch(); + + // only 1 at any length + Socket.Received(2).Send(Arg.Any(), Arg.Any(), Arg.Any()); + // but also check we received length + Socket.Received(1).Send(Arg.Any(), Arg.Any(), AckSystem.NOTIFY_HEADER_SIZE + counts[0]); + Socket.Received(1).Send(Arg.Any(), Arg.Any(), AckSystem.NOTIFY_HEADER_SIZE + counts[1]); + + // check packet was correct + CheckMessage(PacketType.Notify, 0, 2, counts.Take(1), AckSystem.NOTIFY_HEADER_SIZE - 1); + CheckMessage(PacketType.Notify, 1, 2, counts.Skip(1).Take(1), AckSystem.NOTIFY_HEADER_SIZE - 1); + + // clear calls after, so we are ready to process next message + Socket.ClearReceivedCalls(); + _sentArrays.Clear(); + } + + [Test] + [Ignore("Not implemented")] + public void NotifyOnDeliveredInvokeAfterReceivingReply() + { + var counts = new List() { 10, 20 }; + var token = _connection.SendNotify(_buffer, 0, counts[0]); + + var action = Substitute.For(); + token.Delivered += action; + action.DidNotReceive().Invoke(); + + // todo receive message here, and then check if Delivered is infact called + } + + [Test] + [Ignore("Not implemented")] + public void NotifyTokenOnDeliveredInvokeAfterReceivingReply() + { + var token = Substitute.For(); + var counts = new List() { 10, 20 }; + _connection.SendNotify(_buffer, 0, counts[0], token); + token.DidNotReceive().OnDelivered(); + + // todo receive message here, and then check if Delivered is infact called + } + } +} diff --git a/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs.meta b/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs.meta new file mode 100644 index 00000000000..3071ba5eb7d --- /dev/null +++ b/Assets/Tests/SocketLayer/AckSystem/PassthroughConnectionTest.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: a78f3f5648f4170458c64f606fa76afd +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: