Skip to content

Advanced Sockets Implementation in C#

Simon Jackson edited this page Jun 7, 2017 · 1 revision

Advanced Sockets Implementation in C#

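Socket programming in C# is much easier than it is in C/C++. Here we use attributes and reflection to re-create and populate incoming packets: each message class is tagged with a numeric ID, and the receiver maps the ID in each incoming packet header back to that class, creates an instance of it, and lets it read its own payload. Check it out!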

Special thanks to Krissa for this great idea!

    // Message Stream
    // class MessageAttribute

    /// <summary>
    /// Marks a message class with the numeric ID that identifies it on the wire.
    /// </summary>
    [AttributeUsage(AttributeTargets.Class)]
    public class MessageAttribute : Attribute
    {
        // Set once by the constructor and never modified afterwards.
        private readonly int m_messageID;

        /// <summary>Creates the attribute for the given message ID.</summary>
        /// <param name="messageID">Numeric identifier of the tagged message class.</param>
        public MessageAttribute(int messageID)
        {
            this.m_messageID = messageID;
        }

        /// <summary>The message ID passed to the constructor.</summary>
        public int ID
        {
            get { return this.m_messageID; }
        }
    }

    


    // class Message

    /// <summary>
    /// Base class for messages carried by <c>MessageStream</c>. A subclass is
    /// tagged with <c>MessageAttribute</c> to give it a wire ID and overrides
    /// <see cref="Read"/> and <see cref="Write"/> to (de)serialize its payload.
    /// </summary>
    public class Message
    {
        // MemberVariables

        // Type -> message ID cache. Static so the attribute lookup happens once
        // per type instead of once per instance (as an instance field the
        // "already cached?" check could never succeed).
        private static Hashtable m_ids = new Hashtable();

        // AccessorMethods

        /// <summary>
        /// The ID declared by this type's <c>MessageAttribute</c>, or 0 when the
        /// type (and its base types) carry no such attribute.
        /// </summary>
        public int ID
        {
            get { return LookupId(this.GetType()); }
        }

        // Constructor

        /// <summary>Creates a message and ensures its type's ID is cached.</summary>
        public Message()
        {
            LookupId(this.GetType());
        }

        // Returns the cached ID for the type, computing and caching it on first use.
        private static int LookupId(Type type)
        {
            // Hashtable supports concurrent readers alongside a single writer,
            // so reads are lock-free and only writes are serialized.
            object cached = m_ids[type];
            if (cached != null)
                return (int)cached;

            int id = 0;
            object[] attributes = type.GetCustomAttributes(typeof(MessageAttribute), true);
            if (attributes.Length > 0)
            {
                MessageAttribute ma = (MessageAttribute)attributes[0];
                id = ma.ID;
            }

            lock (m_ids.SyncRoot)
            {
                // Indexer assignment (not Add) so a concurrent first lookup of the
                // same type storing the same value is harmless.
                m_ids[type] = id;
            }
            return id;
        }

        // virtual void Read( int size, BinaryReader reader )

        /// <summary>
        /// Reads this message's payload. The base implementation reads nothing.
        /// </summary>
        /// <param name="size">Payload size in bytes as advertised by the sender.</param>
        /// <param name="reader">Reader positioned at the start of the payload.</param>
        public virtual void Read(int size, BinaryReader reader)
        {
        }

        // virtual void Write( BinaryWriter writer )

        /// <summary>
        /// Writes this message's payload. The base implementation writes nothing.
        /// </summary>
        /// <param name="writer">Writer that collects the payload bytes.</param>
        public virtual void Write(BinaryWriter writer)
        {
        }
    }

    

    // class MessageStream

    /// <summary>
    /// Sends and receives <c>Message</c> objects over a connected socket.
    /// Each frame is a 12-byte header of three Int32 values (a per-stream
    /// sequence index, the message ID, and the payload size in bytes),
    /// followed by the payload itself.
    /// </summary>
    public class MessageStream
    {
        // Header = sequence index + message ID + payload size, 4 bytes each.
        private const int HeaderSize = 12;

        // MemberVariables

        private Socket m_networkSocket;
        private NetworkStream m_networkStream;

        private BinaryReader m_networkReader;
        private BinaryWriter m_networkWriter;

        // True once the current frame's header has been consumed; the three
        // fields below then describe the frame whose payload is still pending.
        private bool m_incomingMessageGotSize = false;
        private int m_incomingMessageIndex;
        private int m_incomingMessageID = 0;
        private int m_incomingMessageSize = 0;

        // Sequence index stamped on the next outgoing frame.
        private int m_outgoingMessageID = 0;

        // Message ID -> Message subclass, filled once by the static constructor.
        private static Hashtable m_messageTypes = new Hashtable();

        // Constructor

        /// <summary>
        /// Registers every <c>Message</c> subclass in the executing assembly that
        /// declares a <c>MessageAttribute</c>, keyed by its ID. Two classes
        /// declaring the same ID make <c>Hashtable.Add</c> throw here.
        /// </summary>
        static MessageStream()
        {
            Assembly currentAssembly = Assembly.GetExecutingAssembly();
            foreach (Type currentType in currentAssembly.GetTypes())
            {
                // ReceiveMessage casts the created instance to Message, so only
                // Message subclasses may be registered.
                if (!typeof(Message).IsAssignableFrom(currentType))
                    continue;

                // inherit: false. MessageAttribute is inherited by default, so an
                // untagged subclass would otherwise report its base class's ID and
                // collide with the base class's own registration.
                object[] attributes = currentType.GetCustomAttributes(typeof(
                    MessageAttribute), false);
                if (attributes.Length < 1)
                    continue;

                MessageAttribute ma = (MessageAttribute)attributes[0];
                m_messageTypes.Add(ma.ID, currentType);
            }
        }

        // MessageStream( Socket s )

        /// <summary>
        /// Wraps <paramref name="s"/> in a NetworkStream that does not own the
        /// socket, with Unicode-encoding reader and writer on top of it.
        /// </summary>
        /// <param name="s">A connected stream socket.</param>
        public MessageStream(Socket s)
        {
            m_networkSocket = s;
            m_networkStream = new NetworkStream(m_networkSocket,
                FileAccess.ReadWrite, false);

            m_networkReader = new BinaryReader(m_networkStream,
                System.Text.Encoding.Unicode);
            m_networkWriter = new BinaryWriter(m_networkStream,
                System.Text.Encoding.Unicode);
        }

        // bool IsMessageAvailable

        /// <summary>
        /// True when a complete frame (header and full payload) is waiting.
        /// The 12-byte header is consumed from the socket as soon as it has
        /// arrived, so once this returns true the caller must call
        /// <see cref="ReceiveMessage"/> before polling again.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The header advertises a negative payload size; framing is lost and the
        /// connection should be closed.
        /// </exception>
        public bool IsMessageAvailable
        {
            get
            {
                if (!m_incomingMessageGotSize)
                {
                    // Wait until the whole header (index + id + size) has arrived.
                    if (m_networkSocket.Available < HeaderSize)
                        return (false);

                    m_incomingMessageIndex = m_networkReader.ReadInt32();
                    m_incomingMessageID = m_networkReader.ReadInt32();
                    m_incomingMessageSize = m_networkReader.ReadInt32();
                    m_incomingMessageGotSize = true;

                    // The size comes from the peer; a negative value can't be framed.
                    if (m_incomingMessageSize < 0)
                        throw new InvalidDataException(
                            "Negative message size in frame header.");
                }

                // Complete only when at least the whole payload has arrived.
                // (The original test was inverted: `size < Available`.)
                return (m_networkSocket.Available >= m_incomingMessageSize);
            }
        }

        // Message ReceiveMessage()

        /// <summary>
        /// Reads the pending frame's payload and returns it as an instance of the
        /// class registered for its ID, or a plain <c>Message</c> for unknown IDs.
        /// Call only after <see cref="IsMessageAvailable"/> returned true.
        /// </summary>
        /// <exception cref="EndOfStreamException">The connection ended mid-payload.</exception>
        public Message ReceiveMessage()
        {
            // determine what type of message it is
            Type messageType = (Type)m_messageTypes[m_incomingMessageID];
            if (messageType == null)
                messageType = typeof(Message);

            // create the type
            Message message = (Message)Activator.CreateInstance(messageType);

            // Consume exactly the advertised payload so the next header lines up
            // even if the message's Read() reads fewer bytes (the base Message,
            // used for unknown IDs, reads none).
            int size = m_incomingMessageSize;
            byte[] payload = m_networkReader.ReadBytes(size);

            // allow us to read the next message (before Read(), so a failing
            // deserializer cannot leave the stream stuck on this frame)
            m_incomingMessageGotSize = false;

            if (payload.Length < size)
                throw new EndOfStreamException(
                    "Connection closed before the full message payload arrived.");

            // read message from a reader confined to this frame's payload
            using (BinaryReader payloadReader = new BinaryReader(
                new MemoryStream(payload, false), System.Text.Encoding.Unicode))
            {
                message.Read(size, payloadReader);
            }

            return (message);
        }

        // void SendMessage( Message m )

        /// <summary>
        /// Serializes <paramref name="m"/> and sends it as a single frame stamped
        /// with the next sequence index.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="m"/> is null.</exception>
        public void SendMessage(Message m)
        {
            if (m == null)
                throw new ArgumentNullException("m");

            // Serialize the payload first so its length is known for the header.
            byte[] payload;
            using (MemoryStream body = new MemoryStream())
            {
                BinaryWriter bodyWriter = new BinaryWriter(body, System.Text.Encoding.Unicode);
                m.Write(bodyWriter);
                bodyWriter.Flush();
                payload = body.ToArray();
            }

            // Assemble header + payload into one buffer so the frame goes out in a
            // single write instead of four small ones.
            using (MemoryStream frame = new MemoryStream(HeaderSize + payload.Length))
            {
                BinaryWriter frameWriter = new BinaryWriter(frame);
                frameWriter.Write(m_outgoingMessageID++);
                frameWriter.Write(m.ID);
                frameWriter.Write(payload.Length);
                frameWriter.Write(payload);
                frameWriter.Flush();

                m_networkWriter.Write(frame.GetBuffer(), 0, (int)frame.Length);
            }
            m_networkWriter.Flush();
        }
    }

    

    

    // Messages

    // MessageEnum

    /// <summary>
    /// Message IDs for the built-in message classes, used as the argument of
    /// their <c>MessageAttribute</c>.
    /// </summary>
    public class MessageEnum
    {
        /// <summary>ID of <c>PingPongMessage</c>.</summary>
        public const int PingPongMessage_id = 1;

        /// <summary>ID of <c>TextMessage</c>.</summary>
        public const int TextMessage_id = 2;
    }

    

    // PingPongMessage

    /// <summary>
    /// Message whose payload is a single 64-bit timestamp (8 bytes on the wire).
    /// </summary>
    [Message(MessageEnum.PingPongMessage_id)]
    public class PingPongMessage : Message
    {
        // Backing store for TimeStamp; starts at 0.
        private long m_timeStamp;

        /// <summary>Creates a ping/pong message with a zero timestamp.</summary>
        public PingPongMessage()
        {
        }

        /// <summary>The timestamp value carried by this message.</summary>
        public long TimeStamp
        {
            get { return m_timeStamp; }
            set { m_timeStamp = value; }
        }

        /// <summary>Reads the timestamp as an Int64; <paramref name="size"/> is not used.</summary>
        public override void Read(int size, BinaryReader reader)
        {
            this.TimeStamp = reader.ReadInt64();
        }

        /// <summary>Writes the timestamp as an Int64.</summary>
        public override void Write(BinaryWriter writer)
        {
            writer.Write(this.TimeStamp);
        }
    }

    

    // TextMessage

    /// <summary>
    /// Message whose payload is one string in BinaryWriter's length-prefixed
    /// string format.
    /// </summary>
    [Message(MessageEnum.TextMessage_id)]
    public class TextMessage : Message
    {
        // Text

        private string m_Text = "";

        /// <summary>The text to send, or the text that was received.</summary>
        public string Text
        {
            get
            {
                return (m_Text);
            }
            set
            {
                m_Text = value;
            }
        }

        // void Read( int size, BinaryReader reader )

        /// <summary>Reads the string payload; <paramref name="size"/> is not used.</summary>
        public override void Read(int size, BinaryReader reader)
        {
            m_Text = reader.ReadString();
        }

        // void Write( BinaryWriter writer )

        /// <summary>Writes the text; a null <see cref="Text"/> is sent as an empty string.</summary>
        public override void Write(BinaryWriter writer)
        {
            // The Text setter accepts null, but BinaryWriter.Write(string) throws
            // ArgumentNullException for null, which would fail the whole send.
            writer.Write(m_Text == null ? "" : m_Text);
        }
    }

    

    

    // ZSocket

    /// <summary>
    /// Wraps a TCP socket with a lazily created <c>MessageStream</c> and two
    /// message queues: messages collected by <see cref="Receive"/> and messages
    /// waiting to be sent by <see cref="Send"/>.
    /// </summary>
    public class ZSocket
    {
        // Member Variables

        private Socket m_sock = null;
        // Stream bound to m_sock; reset whenever m_sock is replaced or closed.
        private MessageStream m_Stream = null;
        private ArrayList m_IncomingMessageList = null;
        private ArrayList m_OutgoingMessageList = null;

        // Constructor

        /// <summary>Creates a wrapper with no socket; use CreateSocket or the Socket setter.</summary>
        public ZSocket()
        {
        }

        // Accessor Methods

        /// <summary>Messages collected by <see cref="Receive"/>; created on first access.</summary>
        public ArrayList IncomingMessageList
        {
            get
            {
                if (m_IncomingMessageList == null)
                {
                    m_IncomingMessageList = new ArrayList();
                }
                return m_IncomingMessageList;
            }
        }

        /// <summary>Messages queued for <see cref="Send"/>; created on first access.</summary>
        public ArrayList OutgoingMessageList
        {
            get
            {
                if (m_OutgoingMessageList == null)
                {
                    m_OutgoingMessageList = new ArrayList();
                }
                return m_OutgoingMessageList;
            }
        }

        /// <summary>
        /// The message stream for the current socket, created on first access once
        /// the socket is connected; null until then.
        /// </summary>
        public MessageStream Stream
        {
            get
            {
                if ((m_Stream == null) && (m_sock != null))
                {
                    if (m_sock.Connected)
                        m_Stream = new MessageStream(m_sock);
                }
                return m_Stream;
            }
        }

        /// <summary>
        /// The underlying socket. Assigning a different socket discards the cached
        /// stream so that it is rebuilt for the new socket.
        /// </summary>
        public Socket Socket
        {
            get { return m_sock; }
            set
            {
                if (m_sock != value)
                    m_Stream = null;
                m_sock = value;
            }
        }

        // bool CreateSocket()

        /// <summary>Creates a new IPv4 TCP socket, replacing the current one.</summary>
        /// <returns>true on success; false if creating the socket threw.</returns>
        public bool CreateSocket()
        {
            try
            {
                // Goes through the setter so a stream bound to an older socket is dropped.
                this.Socket = new System.Net.Sockets.Socket(AddressFamily.InterNetwork,
                    SocketType.Stream,
                    ProtocolType.Tcp);
                return true;
            }
            catch (Exception)
            {
                return false;
            }
        }

        // bool Connect(string host,int port)

        /// <summary>
        /// Resolves <paramref name="host"/> and connects to the first address whose
        /// family matches the socket (an IPv4 socket cannot connect to an IPv6
        /// address).
        /// </summary>
        /// <returns>true if connected; false on any resolution or connection failure.</returns>
        public bool Connect(string host, int port)
        {
            try
            {
                IPAddress target = null;
                foreach (IPAddress candidate in Dns.GetHostAddresses(host))
                {
                    if (candidate.AddressFamily == m_sock.AddressFamily)
                    {
                        target = candidate;
                        break;
                    }
                }
                if (target == null)
                    return false;

                m_sock.Connect(new IPEndPoint(target, port));
                return m_sock.Connected;
            }
            catch (Exception)
            {
                // This API reports failure through its return value.
                return false;
            }
        }

        // void Disconnect()

        /// <summary>
        /// Clears both message queues, shuts the socket down if it is connected,
        /// and always closes it.
        /// </summary>
        public void Disconnect()
        {
            if (m_sock == null)
                return;

            IncomingMessageList.Clear();
            OutgoingMessageList.Clear();
            m_Stream = null;

            try
            {
                m_sock.Shutdown(SocketShutdown.Both);
            }
            catch (SocketException)
            {
                // Not connected (e.g. a listening socket); closing below still applies.
            }
            catch (ObjectDisposedException)
            {
                // Already closed.
            }
            finally
            {
                // Previously skipped whenever Shutdown threw, leaking the socket.
                m_sock.Close();
            }
        }

        // bool Receive()

        /// <summary>
        /// Moves every complete incoming message into <see cref="IncomingMessageList"/>.
        /// </summary>
        /// <returns>true if at least one message was received.</returns>
        public bool Receive()
        {
            if (m_sock == null || !m_sock.Connected)
                return false;

            MessageStream stream = this.Stream;
            if (stream == null)
                return false;

            bool received = false;
            while (stream.IsMessageAvailable)
            {
                IncomingMessageList.Add(stream.ReceiveMessage());
                received = true;
            }
            return received;
        }

        // void Send()

        /// <summary>
        /// Sends every queued message in order, then clears the queue. Does nothing
        /// when there is no connected socket.
        /// </summary>
        public void Send()
        {
            if (m_sock == null || !m_sock.Connected)
                return;

            MessageStream stream = this.Stream;
            if (stream == null)
                return;

            foreach (Message m in OutgoingMessageList)
            {
                stream.SendMessage(m);
            }
            OutgoingMessageList.Clear();
        }
    }
    

    // ServerManager
    /// <summary>
    /// Accepts TCP clients on a listening port and keeps one <c>ZSocket</c> per
    /// accepted connection; <see cref="PollSockets"/> reads pending messages
    /// from every connection.
    /// </summary>
    /// <remarks>
    /// New connections are added from the BeginAccept completion callback, which
    /// runs on a different thread from the caller of PollSockets, so the
    /// connection list is only touched under its SyncRoot lock. Code outside this
    /// class that enumerates <see cref="ConnectedSockets"/> should lock the same
    /// SyncRoot.
    /// </remarks>
    public class ServerManager
    {
        // MemberVariables
        ZSocket m_ListeningSocket = null;
        int m_ListeningPort = 0;
        private bool m_bLocalOnly = false;
        // Created eagerly: lazy creation would race with the accept callback.
        private readonly ArrayList m_ConnectedSockets = new ArrayList();
        ArrayList m_ActiveSockets = null;

        // Accessor Methods

        /// <summary>All accepted client connections (ArrayList of ZSocket).</summary>
        public ArrayList ConnectedSockets
        {
            get { return m_ConnectedSockets; }
        }

        /// <summary>
        /// Connections that received at least one message during the last
        /// <see cref="PollSockets"/> call (ArrayList of ZSocket).
        /// </summary>
        public ArrayList ActiveSockets
        {
            get
            {
                if (m_ActiveSockets == null)
                {
                    m_ActiveSockets = new ArrayList();
                }
                return m_ActiveSockets;
            }
        }

        /// <summary>The listening socket wrapper; null before StartServer.</summary>
        public ZSocket Socket
        {
            get { return m_ListeningSocket; }
        }

        // bool PollSockets()

        /// <summary>
        /// Receives pending messages on every connected socket and rebuilds
        /// <see cref="ActiveSockets"/> from those that got any.
        /// </summary>
        /// <returns>true if any connection received a message.</returns>
        public bool PollSockets()
        {
            ActiveSockets.Clear();

            // Enumerate a snapshot: the accept callback may add clients concurrently,
            // which would otherwise invalidate the enumerator.
            object[] clients;
            lock (m_ConnectedSockets.SyncRoot)
            {
                clients = m_ConnectedSockets.ToArray();
            }

            foreach (ZSocket s in clients)
            {
                if (s.Receive())
                {
                    ActiveSockets.Add(s);
                }
            }
            return (ActiveSockets.Count > 0);
        }

        // void StopServer()

        /// <summary>
        /// Closes the listening socket and detaches it from its wrapper. Does
        /// nothing if the server was never started.
        /// </summary>
        public void StopServer()
        {
            ZSocket listener = m_ListeningSocket;
            if (listener == null)
                return;

            listener.Disconnect();
            listener.Socket = null;
        }

        // void StartServer(int port,bool LocalOnly)

        /// <summary>
        /// Binds a listening socket to <paramref name="port"/> (on loopback when
        /// <paramref name="LocalOnly"/> is true, otherwise on this host's first
        /// address of the socket's family) and begins accepting connections.
        /// Failures are logged, not thrown.
        /// </summary>
        public void StartServer(int port, bool LocalOnly)
        {
            m_ListeningPort = port;
            m_bLocalOnly = LocalOnly;

            try
            {
                m_ListeningSocket = new ZSocket();
                if (!m_ListeningSocket.CreateSocket())
                {
                    System.Diagnostics.Debug.WriteLine("Unable to create listening socket");
                    return;
                }

                System.Net.Sockets.Socket listenSocket = m_ListeningSocket.Socket;

                IPAddress bindAddress;
                if (m_bLocalOnly)
                {
                    bindAddress = IPAddress.Loopback;
                }
                else
                {
                    bindAddress = FindLocalAddress(listenSocket.AddressFamily);

                    // Verify we got an IP address. Tell the user if we did not
                    if (bindAddress == null)
                    {
                        Console.WriteLine("Unable to get local address");
                        listenSocket.Close();
                        return;
                    }
                }

                listenSocket.Bind(new IPEndPoint(bindAddress, m_ListeningPort));
                listenSocket.Listen(10);

                // Setup a callback to be notified of connection requests
                listenSocket.BeginAccept(
                           new AsyncCallback(OnConnectRequest), m_ListeningSocket);
            }
            catch (Exception ex)
            {
                System.Diagnostics.Debug.WriteLine(ex);

                // Release the handle (and any bound port) if startup failed part-way.
                if (m_ListeningSocket != null && m_ListeningSocket.Socket != null)
                    m_ListeningSocket.Socket.Close();
            }
        }

        // Returns this host's first address (resolved by host name) of the given
        // family, or null if lookup fails or none matches. The filter matters
        // because the lookup can return IPv6 addresses, which an InterNetwork
        // socket cannot bind.
        private static IPAddress FindLocalAddress(AddressFamily family)
        {
            IPAddress[] aryLocalAddr;
            try
            {
                // NOTE: DNS lookups are nice and all but quite time consuming.
                aryLocalAddr = Dns.GetHostAddresses(Dns.GetHostName());
            }
            catch (Exception ex)
            {
                Console.WriteLine("Error trying to get local address {0} ", ex.Message);
                return null;
            }

            foreach (IPAddress address in aryLocalAddr)
            {
                if (address.AddressFamily == family)
                    return address;
            }
            return null;
        }

        // void OnConnectRequest( IAsyncResult ar )

        // Accept-completion callback: registers the new client and re-arms
        // BeginAccept. Every exception is handled here, because one escaping an
        // async callback is unhandled and terminates the process.
        private void OnConnectRequest(IAsyncResult ar)
        {
            ZSocket listener = (ZSocket)ar.AsyncState;
            System.Net.Sockets.Socket listenSocket = listener.Socket;
            if (listenSocket == null)
                return; // StopServer has already detached the socket.

            System.Net.Sockets.Socket client = null;
            try
            {
                client = listenSocket.EndAccept(ar);
            }
            catch (ObjectDisposedException)
            {
                // The listening socket was closed by StopServer; stop accepting.
                return;
            }
            catch (SocketException ex)
            {
                // A single failed accept should not stop the server.
                System.Diagnostics.Debug.WriteLine(ex);
            }

            if (client != null)
                NewConnection(client);

            try
            {
                listenSocket.BeginAccept(new AsyncCallback(OnConnectRequest), listener);
            }
            catch (ObjectDisposedException)
            {
                // Closed between EndAccept and here; nothing more to accept.
            }
            catch (SocketException ex)
            {
                System.Diagnostics.Debug.WriteLine(ex);
            }
        }

        // void NewConnection( Socket sockClient )

        // Wraps an accepted client socket and adds it to the connection list
        // under the list's lock (see PollSockets).
        private void NewConnection(Socket sockClient)
        {
            ZSocket zs = new ZSocket();
            zs.Socket = sockClient;
            lock (m_ConnectedSockets.SyncRoot)
            {
                m_ConnectedSockets.Add(zs);
            }
        }
    }
    ```
Clone this wiki locally