Skip to content

Commit

Permalink
netplay: use WzConnectionProvider to create instances of descriptor…
Browse files Browse the repository at this point in the history
… sets

Get rid of `IDescriptorSet::create()` crutch and move the
code to create IDescriptorSet:s to `WzConnectionProvider`, where
it belongs.

Signed-off-by: Pavel Solodovnikov <pavel.al.solodovnikov@gmail.com>
  • Loading branch information
ManManson authored and past-due committed Feb 23, 2025
1 parent 0b57a36 commit 7502514
Show file tree
Hide file tree
Showing 14 changed files with 95 additions and 90 deletions.
66 changes: 0 additions & 66 deletions lib/netplay/descriptor_set.cpp

This file was deleted.

5 changes: 0 additions & 5 deletions lib/netplay/descriptor_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,6 @@ class IDescriptorSet
{
public:

/// <summary>
/// Factory method to create descriptor sets for a given poll event type (e.g., readable or writable check).
/// </summary>
static std::unique_ptr<IDescriptorSet> create(PollEventType eventType);

virtual ~IDescriptorSet() = default;

virtual bool add(IClientConnection* conn) = 0;
Expand Down
5 changes: 3 additions & 2 deletions lib/netplay/netplay.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1560,8 +1560,9 @@ int NETinit(bool bFirstCall)
NET_InitPlayers(true, true);

ConnectionProviderRegistry::Instance().Register(ConnectionProviderType::TCP_DIRECT);
ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT).initialize();
PendingWritesManager::instance().initialize();
auto& connProvider = ConnectionProviderRegistry::Instance().Get(ConnectionProviderType::TCP_DIRECT);
connProvider.initialize();
PendingWritesManager::instance().initialize(connProvider);

if (bFirstCall)
{
Expand Down
7 changes: 5 additions & 2 deletions lib/netplay/pending_writes_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "lib/framework/frame.h" // for ASSERT
#include "lib/netplay/descriptor_set.h"
#include "lib/netplay/client_connection.h"
#include "lib/netplay/wz_connection_provider.h"

#include <system_error>

Expand All @@ -46,6 +47,7 @@ void PendingWritesManager::deinitialize()
return;
}
wzMutexLock(mtx_);
connProvider_ = nullptr;
stopRequested_ = true;
pendingWrites_.clear();
wzMutexUnlock(mtx_);
Expand All @@ -56,13 +58,14 @@ void PendingWritesManager::deinitialize()
thread_ = nullptr;
}

void PendingWritesManager::initialize()
void PendingWritesManager::initialize(WzConnectionProvider& connProvider)
{
if (thread_ != nullptr)
{
// No-op in case of a repeated `initialize()` call
return;
}
connProvider_ = &connProvider;
stopRequested_ = false;
mtx_ = wzMutexCreate();
sema_ = wzSemaphoreCreate(0);
Expand Down Expand Up @@ -126,7 +129,7 @@ void PendingWritesManager::threadImplFunction()
while (!stopRequested_)
{
static constexpr std::chrono::milliseconds WRITABLE_CHECK_TIMEOUT{ 50 };
static std::unique_ptr<IDescriptorSet> writableSet = IDescriptorSet::create(PollEventType::WRITABLE);
static std::unique_ptr<IDescriptorSet> writableSet = connProvider_->newDescriptorSet(PollEventType::WRITABLE);

// Check if we can write to some connections.
writableSet->clear();
Expand Down
4 changes: 3 additions & 1 deletion lib/netplay/pending_writes_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct WZ_SEMAPHORE;

class IClientConnection;
class IDescriptorSet;
class WzConnectionProvider;

/// This is a wrapper function that acts as a proxy to `PendingWritesManager::threadImplFunction`.
/// Argument is `PendingWritesManager*` cast to `void*`.
Expand Down Expand Up @@ -70,7 +71,7 @@ class PendingWritesManager

static PendingWritesManager& instance();

void initialize();
void initialize(WzConnectionProvider& connProvider);
void deinitialize();

template <typename Fn>
Expand Down Expand Up @@ -141,4 +142,5 @@ class PendingWritesManager
WZ_SEMAPHORE* sema_ = nullptr;
WZ_THREAD* thread_ = nullptr;
bool stopRequested_ = false;
WzConnectionProvider* connProvider_ = nullptr;
};
8 changes: 5 additions & 3 deletions lib/netplay/tcp/tcp_client_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,20 @@
#include "lib/framework/string_ext.h"
#include "lib/netplay/error_categories.h"
#include "lib/netplay/polling_util.h"
#include "lib/netplay/wz_connection_provider.h"
#include "lib/netplay/tcp/tcp_client_connection.h"
#include "lib/netplay/tcp/netsocket.h"
#include "lib/netplay/tcp/sock_error.h"

namespace tcp
{

TCPClientConnection::TCPClientConnection(Socket* rawSocket)
TCPClientConnection::TCPClientConnection(WzConnectionProvider& connProvider, Socket* rawSocket)
: socket_(rawSocket),
selfConnList_({ this }),
readAllDescriptorSet_(IDescriptorSet::create(PollEventType::READABLE)),
connStatusDescriptorSet_(IDescriptorSet::create(PollEventType::READABLE))
connProvider_(&connProvider),
readAllDescriptorSet_(connProvider_->newDescriptorSet(PollEventType::READABLE)),
connStatusDescriptorSet_(connProvider_->newDescriptorSet(PollEventType::READABLE))
{
ASSERT(socket_ != nullptr, "Null socket passed to TCPClientConnection ctor");
}
Expand Down
5 changes: 4 additions & 1 deletion lib/netplay/tcp/tcp_client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "lib/netplay/descriptor_set.h"
#include "lib/netplay/tcp/netsocket.h" // for SOCKET

class WzConnectionProvider;

namespace tcp
{

Expand All @@ -32,7 +34,7 @@ class TCPClientConnection : public IClientConnection
{
public:

explicit TCPClientConnection(Socket* rawSocket);
explicit TCPClientConnection(WzConnectionProvider& connProvider, Socket* rawSocket);
virtual ~TCPClientConnection() override;

virtual net::result<ssize_t> readAll(void* buf, size_t size, unsigned timeout) override;
Expand Down Expand Up @@ -63,6 +65,7 @@ class TCPClientConnection : public IClientConnection
// (like `readAll()` and `connectionStatus()`) to avoid extra
// memory allocations.
const std::vector<IClientConnection*> selfConnList_;
WzConnectionProvider* connProvider_;
std::unique_ptr<IDescriptorSet> readAllDescriptorSet_;
std::unique_ptr<IDescriptorSet> connStatusDescriptorSet_;
};
Expand Down
6 changes: 4 additions & 2 deletions lib/netplay/tcp/tcp_connection_poll_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "lib/netplay/tcp/tcp_client_connection.h"
#include "lib/netplay/tcp/netsocket.h"
#include "lib/netplay/polling_util.h"
#include "lib/netplay/wz_connection_provider.h"
#include "lib/framework/wzapp.h"
#include "lib/framework/debug.h"

Expand All @@ -29,8 +30,9 @@
namespace tcp
{

TCPConnectionPollGroup::TCPConnectionPollGroup()
: readableSet_(IDescriptorSet::create(PollEventType::READABLE))
TCPConnectionPollGroup::TCPConnectionPollGroup(WzConnectionProvider& connProvider)
: connProvider_(&connProvider),
readableSet_(connProvider_->newDescriptorSet(PollEventType::READABLE))
{}

net::result<int> TCPConnectionPollGroup::checkConnectionsReadable(std::chrono::milliseconds timeout)
Expand Down
4 changes: 3 additions & 1 deletion lib/netplay/tcp/tcp_connection_poll_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

class IClientConnection;
class IDescriptorSet;
class WzConnectionProvider;

namespace tcp
{
Expand All @@ -34,7 +35,7 @@ class TCPConnectionPollGroup : public IConnectionPollGroup
{
public:

explicit TCPConnectionPollGroup();
explicit TCPConnectionPollGroup(WzConnectionProvider& connProvider);
virtual ~TCPConnectionPollGroup() override = default;

virtual net::result<int> checkConnectionsReadable(std::chrono::milliseconds timeout) override;
Expand All @@ -44,6 +45,7 @@ class TCPConnectionPollGroup : public IConnectionPollGroup
private:

std::vector<IClientConnection*> conns_;
WzConnectionProvider* connProvider_;
// Pre-allocated descriptor set for `checkConnectionsReadable` operation
// to avoid extra memory allocations.
std::unique_ptr<IDescriptorSet> readableSet_;
Expand Down
48 changes: 45 additions & 3 deletions lib/netplay/tcp/tcp_connection_provider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@
#include "lib/netplay/open_connection_result.h"
#include "lib/framework/wzapp.h"

#ifdef WZ_OS_WIN
# include "lib/netplay/tcp/select_descriptor_set.h"
#else
# include "lib/netplay/tcp/poll_descriptor_set.h"
#endif

namespace tcp
{

Expand Down Expand Up @@ -58,7 +64,7 @@ net::result<IListenSocket*> TCPConnectionProvider::openListenSocket(uint16_t por
{
return tl::make_unexpected(res.error());
}
return new TCPListenSocket(res.value());
return new TCPListenSocket(*this, res.value());
}

net::result<IClientConnection*> TCPConnectionProvider::openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout)
Expand All @@ -75,12 +81,48 @@ net::result<IClientConnection*> TCPConnectionProvider::openClientConnectionAny(c
{
return tl::make_unexpected(res.error());
}
return new TCPClientConnection(res.value());
return new TCPClientConnection(*this, res.value());
}

IConnectionPollGroup* TCPConnectionProvider::newConnectionPollGroup()
{
return new TCPConnectionPollGroup();
return new TCPConnectionPollGroup(*this);
}

std::unique_ptr<IDescriptorSet> TCPConnectionProvider::newDescriptorSet(PollEventType eventType)
{
// For now, use `select()` on Windows instead of `poll()` because of a bug in
// Windows versions prior to "Windows 10 2004", which can lead to `poll()`
// function timing out on socket connection errors instead of returning an error early.
//
// For more information on the bug, see: https://stackoverflow.com/questions/21653003/is-this-wsapoll-bug-for-non-blocking-sockets-fixed
// and also https://learn.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsapoll#remarks
switch (eventType)
{
case PollEventType::READABLE:
{
#ifdef WZ_OS_WIN
IDescriptorSet* rawPtr = new tcp::SelectDescriptorSet<PollEventType::READABLE>();
return std::unique_ptr<IDescriptorSet>(rawPtr);
#else
IDescriptorSet* rawPtr = new tcp::PollDescriptorSet<PollEventType::READABLE>();
return std::unique_ptr<IDescriptorSet>(rawPtr);
#endif
}
case PollEventType::WRITABLE:
{
#ifdef WZ_OS_WIN
IDescriptorSet* rawPtr = new tcp::SelectDescriptorSet<PollEventType::WRITABLE>();
return std::unique_ptr<IDescriptorSet>(rawPtr);
#else
IDescriptorSet* rawPtr = new tcp::PollDescriptorSet<PollEventType::WRITABLE>();
return std::unique_ptr<IDescriptorSet>(rawPtr);
#endif
}
default:
ASSERT(false, "Unexpected PollEventType value: %d", static_cast<int>(eventType));
return nullptr;
}
}

} // namespace tcp
2 changes: 2 additions & 0 deletions lib/netplay/tcp/tcp_connection_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class TCPConnectionProvider final : public WzConnectionProvider
virtual net::result<IClientConnection*> openClientConnectionAny(const IConnectionAddress& addr, unsigned timeout) override;

virtual IConnectionPollGroup* newConnectionPollGroup() override;

virtual std::unique_ptr<IDescriptorSet> newDescriptorSet(PollEventType eventType) override;
};

} // namespace tcp
7 changes: 4 additions & 3 deletions lib/netplay/tcp/tcp_listen_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
namespace tcp
{

TCPListenSocket::TCPListenSocket(Socket* rawSocket)
: listenSocket_(rawSocket)
TCPListenSocket::TCPListenSocket(WzConnectionProvider& connProvider, Socket* rawSocket)
: listenSocket_(rawSocket),
connProvider_(&connProvider)
{}

TCPListenSocket::~TCPListenSocket()
Expand All @@ -50,7 +51,7 @@ IClientConnection* TCPListenSocket::accept()
{
return nullptr;
}
return new TCPClientConnection(s);
return new TCPClientConnection(*connProvider_, s);
}

IListenSocket::IPVersionsMask TCPListenSocket::supportedIpVersions() const
Expand Down
5 changes: 4 additions & 1 deletion lib/netplay/tcp/tcp_listen_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include "lib/netplay/listen_socket.h"

class WzConnectionProvider;

namespace tcp
{

Expand All @@ -32,7 +34,7 @@ class TCPListenSocket : public IListenSocket
{
public:

explicit TCPListenSocket(tcp::Socket* rawSocket);
explicit TCPListenSocket(WzConnectionProvider& connProvider, tcp::Socket* rawSocket);
virtual ~TCPListenSocket() override;

virtual IClientConnection* accept() override;
Expand All @@ -41,6 +43,7 @@ class TCPListenSocket : public IListenSocket
private:

tcp::Socket* listenSocket_;
WzConnectionProvider* connProvider_;
};

} // namespace tcp
Loading

0 comments on commit 7502514

Please sign in to comment.