Browse Source

Add accept method to Socket class

Patrick-Christopher Mattulat 1 năm trước cách đây
mục cha
commit
17279c9134

+ 1 - 0
include/ls_std/core/interface/IPosixSocket.hpp

@@ -20,6 +20,7 @@ namespace ls::std::core::interface_type
 
       virtual ~IPosixSocket() = default;
 
+      virtual int accept(int _socketFileDescriptor, struct sockaddr *_address, socklen_t* _addressLength) = 0;
       virtual int bind(int _socketFileDescriptor, const struct sockaddr* _address, socklen_t _addressLength) = 0;
       virtual int close(int _socketFileDescriptor) = 0;
       virtual int connect(int _socketFileDescriptor, const struct sockaddr* _address, socklen_t _addressLength) = 0;

+ 2 - 1
include/ls_std/network/socket/ConvertedSocketAddress.hpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2022-11-18
- * Changed:         2022-11-18
+ * Changed:         2022-12-12
  *
  * */
 
@@ -20,6 +20,7 @@ namespace ls::std::network
   {
     #if defined(unix) || defined(__APPLE__)
     ::sockaddr_in socketAddressUnix{};
+    ::socklen_t addressLength{};
     #endif
   };
 }

+ 1 - 0
include/ls_std/network/socket/MockPosixSocket.hpp

@@ -22,6 +22,7 @@ namespace ls_std_network_test
       MockPosixSocket() = default;
       ~MockPosixSocket() override = default;
 
+      MOCK_METHOD(int, accept, (int _socketFileDescriptor, struct sockaddr *_address, socklen_t* _addressLength), (override));
       MOCK_METHOD(int, bind, (int _socketFileDescriptor, const struct sockaddr *_address, socklen_t _addressLength), (override));
       MOCK_METHOD(int, close, (int _socketFileDescriptor), (override));
       MOCK_METHOD(int, connect, (int _socketFileDescriptor, const struct sockaddr *_address, socklen_t _addressLength), (override));

+ 1 - 0
include/ls_std/network/socket/PosixSocket.hpp

@@ -21,6 +21,7 @@ namespace ls::std::network
       PosixSocket() = default;
       ~PosixSocket() override = default;
 
+      int accept(int _socketFileDescriptor, struct sockaddr *_address, socklen_t* _addressLength) override;
       int bind(int _socketFileDescriptor, const struct sockaddr *_address, socklen_t _addressLength) override;
       int close(int _socketFileDescriptor) override;
       int connect(int _socketFileDescriptor, const struct sockaddr *_address, socklen_t _addressLength) override;

+ 2 - 0
include/ls_std/network/socket/Socket.hpp

@@ -28,6 +28,7 @@ namespace ls::std::network
       explicit Socket(ls::std::network::SocketParameter _parameter);
       ~Socket() override = default;
 
+      [[nodiscard]] bool accept();
       [[nodiscard]] bool bind();
       [[nodiscard]] bool close();
       [[nodiscard]] bool connect();
@@ -43,6 +44,7 @@ namespace ls::std::network
       #endif
 
       #if defined(unix) || defined(__APPLE__)
+      [[nodiscard]] bool _acceptUnix();
       [[nodiscard]] bool _bindUnix();
       [[nodiscard]] bool _closeUnix();
       [[nodiscard]] bool _connectUnix();

+ 5 - 0
source/ls_std/network/socket/PosixSocket.cpp

@@ -11,6 +11,11 @@
 #include <sys/socket.h>
 #include <unistd.h>
 
+int ls::std::network::PosixSocket::accept(int _socketFileDescriptor, struct sockaddr *_address, socklen_t* _addressLength)
+{
+  return ::accept(_socketFileDescriptor, _address, _addressLength);
+}
+
 int ls::std::network::PosixSocket::bind(int _socketFileDescriptor, const struct sockaddr *_address, socklen_t _addressLength)
 {
   return ::bind(_socketFileDescriptor, _address, _addressLength);

+ 21 - 2
source/ls_std/network/socket/Socket.cpp

@@ -27,6 +27,18 @@ parameter(::std::move(_parameter))
   #endif
 }
 
+bool ls::std::network::Socket::accept()
+{
+  if (this->parameter.socketAddress.protocolType != PROTOCOL_TYPE_TCP)
+  {
+    throw ls::std::core::WrongProtocolException{};
+  }
+
+  #if defined(unix) || defined(__APPLE__)
+  return ls::std::network::Socket::_acceptUnix();
+  #endif
+}
+
 bool ls::std::network::Socket::bind()
 {
   #if defined(unix) || defined(__APPLE__)
@@ -66,10 +78,17 @@ bool ls::std::network::Socket::listen()
 }
 
 #if defined(unix) || defined(__APPLE__)
+
+bool ls::std::network::Socket::_acceptUnix()
+{
+  ls::std::network::ConvertedSocketAddress convertedSocketAddress = ls::std::network::SocketAddressMapper::from(ls::std::network::Socket::_createSocketAddressMapperParameter());
+  return this->parameter.posixSocket->accept(this->unixDescriptor, reinterpret_cast<sockaddr *>(&convertedSocketAddress.socketAddressUnix), &convertedSocketAddress.addressLength) >= 0;
+}
+
 bool ls::std::network::Socket::_bindUnix()
 {
   ls::std::network::ConvertedSocketAddress convertedSocketAddress = ls::std::network::SocketAddressMapper::from(ls::std::network::Socket::_createSocketAddressMapperParameter());
-  return this->parameter.posixSocket->bind(this->unixDescriptor, reinterpret_cast<const sockaddr *>(&convertedSocketAddress.socketAddressUnix), sizeof(convertedSocketAddress.socketAddressUnix)) == 0;
+  return this->parameter.posixSocket->bind(this->unixDescriptor, reinterpret_cast<const sockaddr *>(&convertedSocketAddress.socketAddressUnix), convertedSocketAddress.addressLength) == 0;
 }
 
 bool ls::std::network::Socket::_closeUnix()
@@ -80,7 +99,7 @@ bool ls::std::network::Socket::_closeUnix()
 bool ls::std::network::Socket::_connectUnix()
 {
   ls::std::network::ConvertedSocketAddress convertedSocketAddress = ls::std::network::SocketAddressMapper::from(ls::std::network::Socket::_createSocketAddressMapperParameter());
-  return this->parameter.posixSocket->connect(this->unixDescriptor, reinterpret_cast<const sockaddr *>(&convertedSocketAddress.socketAddressUnix), sizeof(convertedSocketAddress.socketAddressUnix)) == 0;
+  return this->parameter.posixSocket->connect(this->unixDescriptor, reinterpret_cast<const sockaddr *>(&convertedSocketAddress.socketAddressUnix), convertedSocketAddress.addressLength) == 0;
 }
 #endif
 

+ 2 - 1
source/ls_std/network/socket/SocketAddressMapper.cpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2022-11-18
- * Changed:         2022-12-09
+ * Changed:         2022-12-12
  *
  * */
 
@@ -24,6 +24,7 @@ ls::std::network::ConvertedSocketAddress ls::std::network::SocketAddressMapper::
 
   #if defined(unix) || defined(__APPLE__)
   convertedSocketAddress.socketAddressUnix = ls::std::network::SocketAddressMapper::_toSockAddressUnix(_parameter);
+  convertedSocketAddress.addressLength = sizeof(convertedSocketAddress.socketAddressUnix);
   #endif
 
   return convertedSocketAddress;

+ 2 - 1
test/cases/network/socket/SocketAddressMapperTest.cpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2020-11-18
- * Changed:         2022-12-09
+ * Changed:         2022-12-12
  *
  * */
 
@@ -59,6 +59,7 @@ namespace
     ASSERT_EQ(36895, convertedSocketAddress.socketAddressUnix.sin_port); // expected: return value of htons()
     ASSERT_EQ(AF_INET, convertedSocketAddress.socketAddressUnix.sin_family);
     ASSERT_EQ(16777343, convertedSocketAddress.socketAddressUnix.sin_addr.s_addr); // expected: return value of inet_aton()
+    ASSERT_EQ(16, convertedSocketAddress.addressLength);
     #endif
   }
 }

+ 45 - 0
test/cases/network/socket/SocketTest.cpp

@@ -51,6 +51,51 @@ namespace
     ASSERT_STREQ("Socket", Socket{generateSocketParameter()}.getClassName().c_str());
   }
 
+  TEST_F(SocketTest, accept)
+  {
+    SocketParameter parameter = generateSocketParameter();
+
+    #if defined(unix) || defined(__APPLE__)
+    shared_ptr<MockPosixSocket> mockSocket = make_shared<MockPosixSocket>();
+    parameter.posixSocket = mockSocket;
+
+    EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, accept(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, accept(_, _, _)).WillByDefault(Return(0));
+    #endif
+
+    Socket socket{parameter};
+    ASSERT_TRUE(socket.accept());
+  }
+
+  TEST_F(SocketTest, accept_wrong_protocol)
+  {
+    SocketParameter parameter = generateSocketParameter();
+    parameter.socketAddress.protocolType = PROTOCOL_TYPE_UDP;
+
+    #if defined(unix) || defined(__APPLE__)
+    shared_ptr<MockPosixSocket> mockSocket = make_shared<MockPosixSocket>();
+    parameter.posixSocket = mockSocket;
+
+    EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    #endif
+
+    Socket socket{parameter};
+
+    EXPECT_THROW({
+                   try
+                   {
+                     bool listened = socket.accept();
+                   }
+                   catch (const WrongProtocolException &_exception)
+                   {
+                     throw;
+                   }
+                 }, WrongProtocolException);
+  }
+
   TEST_F(SocketTest, bind)
   {
     SocketParameter parameter = generateSocketParameter();