浏览代码

Fix accept method of Socket class

- method now returns a connection ID, which can be used to reference it
for read and write operations
- added socket close calls to Socket class destructor
Patrick-Christopher Mattulat 2 年之前
父节点
当前提交
dcb4fc11ef

+ 3 - 2
include/ls_std/core/types/SocketTypes.hpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2020-11-16
- * Changed:         2022-11-16
+ * Changed:         2022-12-27
  *
  * */
 
@@ -14,8 +14,9 @@
 
 namespace ls::std::core::type
 {
-  using port = unsigned short;
+  using connection_id = int;
   using ip_address = ::std::string;
+  using port = unsigned short;
 }
 
 #endif

+ 21 - 7
include/ls_std/network/socket/Socket.hpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2022-11-16
- * Changed:         2022-12-26
+ * Changed:         2022-12-27
  *
  * */
 
@@ -19,6 +19,8 @@
 #include <ls_std/core/types/Types.hpp>
 #include <ls_std/core/interface/IReader.hpp>
 #include <ls_std/core/interface/IWriter.hpp>
+#include <unordered_map>
+#include <ls_std/core/types/SocketTypes.hpp>
 
 namespace ls::std::network
 {
@@ -31,35 +33,47 @@ namespace ls::std::network
 
       // implementation
 
-      [[nodiscard]] ls::std::core::type::byte_field read() override;
-      [[nodiscard]] bool write(const ls::std::core::type::byte_field &_data) override;
+      [[nodiscard]] ls::std::core::type::byte_field read() override; // TODO: set write descriptor
+      [[nodiscard]] bool write(const ls::std::core::type::byte_field &_data) override; // TODO: set write descriptor
 
       // other functionalities
 
-      [[nodiscard]] bool accept();
+      [[nodiscard]] ls::std::core::type::connection_id accept();
       [[nodiscard]] bool bind();
       [[nodiscard]] bool close();
       [[nodiscard]] bool connect();
+      [[nodiscard]] bool handle(const ls::std::core::type::connection_id& _acceptedConnectionId);
+      // void handle() for this Socket
       [[nodiscard]] bool isInitialized() const;
       [[nodiscard]] bool listen();
 
     private:
 
+      ls::std::core::type::connection_id currentAcceptedConnection{};
       bool initialized{};
       ls::std::network::SocketParameter parameter{};
       ls::std::core::type::byte* readBuffer{};
       bool readBufferSet{};
       #if LS_STD_UNIX_PLATFORM
-      int unixDescriptor{};
+      ::std::unordered_map<ls::std::core::type::connection_id, int> unixAcceptDescriptors{}; // TODO: provide a struct with connection information
+      int unixDescriptor{}; // TODO: also introduce "handle" usage - must be possible to set it for read and write
+      ls::std::core::type::connection_id unixUniqueDescriptorId{};
       #endif
 
       #if LS_STD_UNIX_PLATFORM
-      [[nodiscard]] bool _acceptUnix();
+      [[nodiscard]] ls::std::core::type::connection_id _acceptUnix();
       [[nodiscard]] bool _bindUnix();
-      [[nodiscard]] bool _closeUnix();
+      #endif
+      [[nodiscard]] bool _close();
+      #if LS_STD_UNIX_PLATFORM
+      [[nodiscard]] bool _closeUnix(const int& _descriptor);
       [[nodiscard]] bool _connectUnix();
       #endif
       [[nodiscard]] SocketAddressMapperParameter _createSocketAddressMapperParameter() const;
+      [[nodiscard]] bool _hasAcceptedConnection(const ls::std::core::type::connection_id& _connectionId);
+      #if LS_STD_UNIX_PLATFORM
+      [[nodiscard]] bool _hasAcceptedConnectionUnix(const ls::std::core::type::connection_id& _connectionId);
+      #endif
       void _init();
       void _initReadBuffer();
       #if LS_STD_UNIX_PLATFORM

+ 73 - 9
source/ls_std/network/socket/Socket.cpp

@@ -20,6 +20,8 @@
 #include <ls_std/core/exception/IllegalArgumentException.hpp>
 #include <ls_std/core/exception/FileOperationException.hpp>
 #include <memory>
+#include <ls_std/core/exception/SocketOperationFailedException.hpp>
+#include <iostream>
 
 ls::std::network::Socket::Socket(ls::std::network::SocketParameter _parameter) : ls::std::core::Class("Socket"),
 parameter(::std::move(_parameter))
@@ -30,6 +32,21 @@ parameter(::std::move(_parameter))
 ls::std::network::Socket::~Socket()
 {
   delete[] this->readBuffer;
+
+  #if LS_STD_UNIX_PLATFORM
+  for (const auto& connection : this->unixAcceptDescriptors)
+  {
+    if (!this->_closeUnix(connection.second))
+    {
+      ::std::cerr << "could not close socket with id \"" << connection.first << "\"" << ::std::endl;
+    }
+  }
+
+  if (!this->_closeUnix(this->unixDescriptor))
+  {
+    ::std::cerr << "could not close socket with descriptor \"" << this->unixDescriptor << "\"" << ::std::endl;
+  }
+  #endif
 }
 
 ls::std::core::type::byte_field ls::std::network::Socket::read()
@@ -48,7 +65,7 @@ bool ls::std::network::Socket::write(const ls::std::core::type::byte_field &_dat
   return this->_write(_data);
 }
 
-bool ls::std::network::Socket::accept()
+ls::std::core::type::connection_id ls::std::network::Socket::accept()
 {
   if (this->parameter.socketAddress.protocolType != PROTOCOL_TYPE_TCP)
   {
@@ -69,9 +86,7 @@ bool ls::std::network::Socket::bind()
 
 bool ls::std::network::Socket::close()
 {
-  #if LS_STD_UNIX_PLATFORM
-  return ls::std::network::Socket::_closeUnix();
-  #endif
+  return this->_close();
 }
 
 bool ls::std::network::Socket::connect()
@@ -81,6 +96,19 @@ bool ls::std::network::Socket::connect()
   #endif
 }
 
+bool ls::std::network::Socket::handle(const ls::std::core::type::connection_id &_acceptedConnectionId)
+{
+  bool focusSet{};
+
+  if (this->_hasAcceptedConnection(_acceptedConnectionId))
+  {
+    this->currentAcceptedConnection = _acceptedConnectionId;
+    focusSet = true;
+  }
+
+  return focusSet;
+}
+
 bool ls::std::network::Socket::isInitialized() const
 {
   return this->initialized;
@@ -99,10 +127,23 @@ bool ls::std::network::Socket::listen()
 }
 
 #if LS_STD_UNIX_PLATFORM
-bool ls::std::network::Socket::_acceptUnix()
+ls::std::core::type::connection_id ls::std::network::Socket::_acceptUnix()
 {
-  ls::std::network::ConvertedSocketAddress convertedSocketAddress = ls::std::network::SocketAddressMapper::from(ls::std::network::Socket::_createSocketAddressMapperParameter());
-  return this->parameter.posixSocket->accept(this->unixDescriptor, reinterpret_cast<sockaddr *>(&convertedSocketAddress.socketAddressUnix), &convertedSocketAddress.addressLength) >= 0;
+  ::sockaddr_in incoming{};
+  ::socklen_t length{};
+  ls::std::core::type::connection_id acceptedDescriptor = this->parameter.posixSocket->accept(this->unixDescriptor, reinterpret_cast<sockaddr *>(&incoming), &length);
+
+  if (acceptedDescriptor >= 0)
+  {
+    ++this->unixUniqueDescriptorId;
+    this->unixAcceptDescriptors.insert({this->unixUniqueDescriptorId, acceptedDescriptor});
+  }
+  else
+  {
+    throw ls::std::core::SocketOperationFailedException{};
+  }
+
+  return this->unixUniqueDescriptorId;
 }
 
 bool ls::std::network::Socket::_bindUnix()
@@ -110,10 +151,19 @@ bool ls::std::network::Socket::_bindUnix()
   ls::std::network::ConvertedSocketAddress convertedSocketAddress = ls::std::network::SocketAddressMapper::from(ls::std::network::Socket::_createSocketAddressMapperParameter());
   return this->parameter.posixSocket->bind(this->unixDescriptor, reinterpret_cast<const sockaddr *>(&convertedSocketAddress.socketAddressUnix), convertedSocketAddress.addressLength) == 0;
 }
+#endif
+
+bool ls::std::network::Socket::_close()
+{
+  #if LS_STD_UNIX_PLATFORM
+  return ls::std::network::Socket::_closeUnix(this->unixDescriptor);
+  #endif
+}
 
-bool ls::std::network::Socket::_closeUnix()
+#if LS_STD_UNIX_PLATFORM
+bool ls::std::network::Socket::_closeUnix(const int& _descriptor)
 {
-  return this->parameter.posixSocket->close(this->unixDescriptor) == 0;
+  return this->parameter.posixSocket->close(_descriptor) == 0;
 }
 
 bool ls::std::network::Socket::_connectUnix()
@@ -132,6 +182,20 @@ ls::std::network::SocketAddressMapperParameter ls::std::network::Socket::_create
   return mapperParameter;
 }
 
+bool ls::std::network::Socket::_hasAcceptedConnection(const ls::std::core::type::connection_id &_connectionId)
+{
+  #if LS_STD_UNIX_PLATFORM
+  return this->_hasAcceptedConnectionUnix(_connectionId);
+  #endif
+}
+
+#if LS_STD_UNIX_PLATFORM
+bool ls::std::network::Socket::_hasAcceptedConnectionUnix(const ls::std::core::type::connection_id &_connectionId)
+{
+  return this->unixAcceptDescriptors.find(_connectionId) != this->unixAcceptDescriptors.end();
+}
+#endif
+
 void ls::std::network::Socket::_init()
 {
   #if LS_STD_UNIX_PLATFORM

+ 61 - 2
test/cases/network/socket/SocketTest.cpp

@@ -3,7 +3,7 @@
  * Company:         Lynar Studios
  * E-Mail:          webmaster@lynarstudios.com
  * Created:         2020-11-16
- * Changed:         2022-12-26
+ * Changed:         2022-12-27
  *
  * */
 
@@ -71,6 +71,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockReader, read(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockReader, read(_, _, _)).WillByDefault(Return(1));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     parameter.readBufferSize = 32;
@@ -89,6 +91,8 @@ namespace
 
     EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -119,6 +123,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockReader, read(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockReader, read(_, _, _)).WillByDefault(Return(-1));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     parameter.readBufferSize = 32;
@@ -150,6 +156,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockWriter, write(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockWriter, write(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -157,7 +165,7 @@ namespace
     ASSERT_TRUE(socket.write("Hello Server!"));
   }
 
-  TEST_F(SocketTest, accept)
+  TEST_F(SocketTest, accept) // TODO: adjust accept tests due to signature change
   {
     SocketParameter parameter = generateSocketParameter();
 
@@ -169,6 +177,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockSocket, accept(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, accept(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -186,6 +196,8 @@ namespace
 
     EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -214,6 +226,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockSocket, bind(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, bind(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -250,12 +264,53 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockSocket, connect(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, connect(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
     ASSERT_TRUE(socket.connect());
   }
 
+  TEST_F(SocketTest, handle)
+  {
+    SocketParameter parameter = generateSocketParameter();
+
+    #if LS_STD_UNIX_PLATFORM
+    shared_ptr<MockPosixSocket> mockSocket = make_shared<MockPosixSocket>();
+    parameter.posixSocket = mockSocket;
+
+    EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, accept(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, accept(_, _, _)).WillByDefault(Return(5));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
+    #endif
+
+    Socket socket{parameter};
+    connection_id acceptedConnection = socket.accept();
+    ASSERT_TRUE(socket.handle(acceptedConnection));
+  }
+
+  TEST_F(SocketTest, handle_no_accepted_connection)
+  {
+    SocketParameter parameter = generateSocketParameter();
+
+    #if LS_STD_UNIX_PLATFORM
+    shared_ptr<MockPosixSocket> mockSocket = make_shared<MockPosixSocket>();
+    parameter.posixSocket = mockSocket;
+
+    EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
+    #endif
+
+    Socket socket{parameter};
+    ASSERT_FALSE(socket.handle(13));
+  }
+
   TEST_F(SocketTest, isInitialized)
   {
     Socket socket{generateSocketParameter()};
@@ -274,6 +329,8 @@ namespace
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
     EXPECT_CALL(*mockSocket, listen(_, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, listen(_, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};
@@ -291,6 +348,8 @@ namespace
 
     EXPECT_CALL(*mockSocket, create(_, _, _)).Times(AtLeast(1));
     ON_CALL(*mockSocket, create(_, _, _)).WillByDefault(Return(0));
+    EXPECT_CALL(*mockSocket, close(_)).Times(AtLeast(1));
+    ON_CALL(*mockSocket, close(_)).WillByDefault(Return(0));
     #endif
 
     Socket socket{parameter};