mirror of
https://github.com/shadps4-emu/shadPS4.git
synced 2026-03-28 14:39:43 -06:00
Net fixes (#3895)
* support for flag in recv/send * make kalapsofos happy * on more SendMessage try * ReceiveMessage too
This commit is contained in:
parent
35da0506b6
commit
256397aa3b
@ -1,4 +1,4 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
|
||||
// SPDX-FileCopyrightText: Copyright 2024-2026 shadPS4 Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#pragma once
|
||||
@ -110,6 +110,15 @@ enum OrbisNetSocketSoOption : u32 {
|
||||
ORBIS_NET_SO_PRIORITY = 0x1203
|
||||
};
|
||||
|
||||
enum OrbisNetFlags : u32 {
|
||||
ORBIS_NET_MSG_PEEK = 0x00000002,
|
||||
ORBIS_NET_MSG_WAITALL = 0x00000040,
|
||||
ORBIS_NET_MSG_DONTWAIT = 0x00000080,
|
||||
ORBIS_NET_MSG_USECRYPTO = 0x00100000,
|
||||
ORBIS_NET_MSG_USESIGNATURE = 0x00200000,
|
||||
ORBIS_NET_MSG_PEEKLEN = (0x00400000 | ORBIS_NET_MSG_PEEK)
|
||||
};
|
||||
|
||||
constexpr std::string_view NameOf(OrbisNetSocketSoOption o) {
|
||||
switch (o) {
|
||||
case ORBIS_NET_SO_REUSEADDR:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
|
||||
// SPDX-FileCopyrightText: Copyright 2024-2026 shadPS4 Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#include <vector>
|
||||
@ -184,28 +184,103 @@ int PosixSocket::Listen(int backlog) {
|
||||
return ConvertReturnErrorCode(::listen(sock, backlog));
|
||||
}
|
||||
|
||||
static int convertOrbisFlagsToPosix(int sock_type, int sce_flags) {
|
||||
int posix_flags = 0;
|
||||
|
||||
if (sce_flags & ORBIS_NET_MSG_PEEK)
|
||||
posix_flags |= MSG_PEEK;
|
||||
#ifndef _WIN32
|
||||
if (sce_flags & ORBIS_NET_MSG_DONTWAIT)
|
||||
posix_flags |= MSG_DONTWAIT;
|
||||
#endif
|
||||
// MSG_WAITALL is only valid for stream sockets
|
||||
if ((sce_flags & ORBIS_NET_MSG_WAITALL) &&
|
||||
((sock_type == ORBIS_NET_SOCK_STREAM) || (sock_type == ORBIS_NET_SOCK_STREAM_P2P)))
|
||||
posix_flags |= MSG_WAITALL;
|
||||
|
||||
return posix_flags;
|
||||
}
|
||||
|
||||
// On Windows, MSG_DONTWAIT is not handled natively by recv/send.
|
||||
// This function uses select() with zero timeout to simulate non-blocking behavior.
|
||||
static int socket_is_ready(int sock, bool is_read = true) {
|
||||
fd_set fds;
|
||||
FD_ZERO(&fds);
|
||||
FD_SET(sock, &fds);
|
||||
timeval timeout{0, 0};
|
||||
int res =
|
||||
select(sock + 1, is_read ? &fds : nullptr, is_read ? nullptr : &fds, nullptr, &timeout);
|
||||
if (res == 0)
|
||||
return ORBIS_NET_ERROR_EWOULDBLOCK;
|
||||
else if (res < 0)
|
||||
return ConvertReturnErrorCode(res);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int PosixSocket::SendMessage(const OrbisNetMsghdr* msg, int flags) {
|
||||
std::scoped_lock lock{m_mutex};
|
||||
|
||||
#ifdef _WIN32
|
||||
DWORD bytesSent = 0;
|
||||
LPFN_WSASENDMSG wsasendmsg = nullptr;
|
||||
GUID guid = WSAID_WSASENDMSG;
|
||||
DWORD bytes = 0;
|
||||
int totalSent = 0;
|
||||
bool waitAll = (flags & ORBIS_NET_MSG_WAITALL) != 0;
|
||||
bool dontWait = (flags & ORBIS_NET_MSG_DONTWAIT) != 0;
|
||||
|
||||
if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), &wsasendmsg,
|
||||
sizeof(wsasendmsg), &bytes, nullptr, nullptr) != 0) {
|
||||
return ConvertReturnErrorCode(-1);
|
||||
// stream socket with multiple buffers
|
||||
bool use_wsamsg =
|
||||
(socket_type == ORBIS_NET_SOCK_STREAM || socket_type == ORBIS_NET_SOCK_STREAM_P2P) &&
|
||||
msg->msg_iovlen > 1;
|
||||
|
||||
for (int i = 0; i < msg->msg_iovlen; ++i) {
|
||||
char* buf = (char*)msg->msg_iov[i].iov_base;
|
||||
size_t remaining = msg->msg_iov[i].iov_len;
|
||||
|
||||
while (remaining > 0) {
|
||||
if (dontWait) {
|
||||
int ready = socket_is_ready(sock, false);
|
||||
if (ready <= 0)
|
||||
return ready;
|
||||
}
|
||||
|
||||
int sent = 0;
|
||||
if (use_wsamsg) {
|
||||
// only call WSASendMsg if we have multiple buffers
|
||||
LPFN_WSASENDMSG wsasendmsg = nullptr;
|
||||
GUID guid = WSAID_WSASENDMSG;
|
||||
DWORD bytes = 0;
|
||||
if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
|
||||
&wsasendmsg, sizeof(wsasendmsg), &bytes, nullptr, nullptr) != 0) {
|
||||
// fallback to send()
|
||||
sent = ::send(sock, buf, remaining, 0);
|
||||
} else {
|
||||
DWORD bytesSent = 0;
|
||||
int res = wsasendmsg(
|
||||
sock, reinterpret_cast<LPWSAMSG>(const_cast<OrbisNetMsghdr*>(msg)), 0,
|
||||
&bytesSent, nullptr, nullptr);
|
||||
if (res == SOCKET_ERROR)
|
||||
return ConvertReturnErrorCode(WSAGetLastError());
|
||||
sent = bytesSent;
|
||||
}
|
||||
} else {
|
||||
sent = ::send(sock, buf, remaining, 0);
|
||||
if (sent == SOCKET_ERROR)
|
||||
return ConvertReturnErrorCode(WSAGetLastError());
|
||||
}
|
||||
|
||||
totalSent += sent;
|
||||
remaining -= sent;
|
||||
buf += sent;
|
||||
|
||||
if (!waitAll)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int res = wsasendmsg(sock, reinterpret_cast<LPWSAMSG>(const_cast<OrbisNetMsghdr*>(msg)), flags,
|
||||
&bytesSent, nullptr, nullptr);
|
||||
return totalSent;
|
||||
|
||||
if (res == SOCKET_ERROR) {
|
||||
return ConvertReturnErrorCode(-1);
|
||||
}
|
||||
return static_cast<int>(bytesSent);
|
||||
#else
|
||||
int res = sendmsg(sock, reinterpret_cast<const msghdr*>(msg), flags);
|
||||
int native_flags = convertOrbisFlagsToPosix(socket_type, flags);
|
||||
int res = sendmsg(sock, reinterpret_cast<const msghdr*>(msg), native_flags);
|
||||
return ConvertReturnErrorCode(res);
|
||||
#endif
|
||||
}
|
||||
@ -213,37 +288,92 @@ int PosixSocket::SendMessage(const OrbisNetMsghdr* msg, int flags) {
|
||||
int PosixSocket::SendPacket(const void* msg, u32 len, int flags, const OrbisNetSockaddr* to,
|
||||
u32 tolen) {
|
||||
std::scoped_lock lock{m_mutex};
|
||||
if (to != nullptr) {
|
||||
sockaddr addr;
|
||||
convertOrbisNetSockaddrToPosix(to, &addr);
|
||||
return ConvertReturnErrorCode(
|
||||
sendto(sock, (const char*)msg, len, flags, &addr, sizeof(sockaddr_in)));
|
||||
} else {
|
||||
return ConvertReturnErrorCode(send(sock, (const char*)msg, len, flags));
|
||||
int res = 0;
|
||||
#ifdef _WIN32
|
||||
if (flags & ORBIS_NET_MSG_DONTWAIT) {
|
||||
res = socket_is_ready(sock, false);
|
||||
if (res <= 0)
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
const auto posix_flags = convertOrbisFlagsToPosix(socket_type, flags);
|
||||
if (to == nullptr) {
|
||||
res = send(sock, (const char*)msg, len, posix_flags);
|
||||
} else {
|
||||
sockaddr addr{};
|
||||
convertOrbisNetSockaddrToPosix(to, &addr);
|
||||
res = sendto(sock, (const char*)msg, len, posix_flags, &addr, tolen);
|
||||
}
|
||||
return ConvertReturnErrorCode(res);
|
||||
}
|
||||
|
||||
int PosixSocket::ReceiveMessage(OrbisNetMsghdr* msg, int flags) {
|
||||
std::scoped_lock lock{receive_mutex};
|
||||
|
||||
#ifdef _WIN32
|
||||
LPFN_WSARECVMSG wsarecvmsg = nullptr;
|
||||
GUID guid = WSAID_WSARECVMSG;
|
||||
DWORD bytes = 0;
|
||||
int totalReceived = 0;
|
||||
bool waitAll = (flags & ORBIS_NET_MSG_WAITALL) != 0;
|
||||
bool dontWait = (flags & ORBIS_NET_MSG_DONTWAIT) != 0;
|
||||
|
||||
if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid), &wsarecvmsg,
|
||||
sizeof(wsarecvmsg), &bytes, nullptr, nullptr) != 0) {
|
||||
return ConvertReturnErrorCode(-1);
|
||||
// stream socket with multiple buffers
|
||||
bool use_wsarecvmsg =
|
||||
(socket_type == ORBIS_NET_SOCK_STREAM || socket_type == ORBIS_NET_SOCK_STREAM_P2P) &&
|
||||
msg->msg_iovlen > 1;
|
||||
|
||||
for (int i = 0; i < msg->msg_iovlen; ++i) {
|
||||
char* buf = (char*)msg->msg_iov[i].iov_base;
|
||||
size_t remaining = msg->msg_iov[i].iov_len;
|
||||
|
||||
while (remaining > 0) {
|
||||
// emulate DONTWAIT
|
||||
if (dontWait) {
|
||||
int ready = socket_is_ready(sock, true);
|
||||
if (ready <= 0)
|
||||
return ready; // returns ORBIS_NET_ERROR_EWOULDBLOCK or error
|
||||
}
|
||||
|
||||
int received = 0;
|
||||
if (use_wsarecvmsg) {
|
||||
// only call WSARecvMsg if multiple buffers + stream
|
||||
LPFN_WSARECVMSG wsarecvmsg = nullptr;
|
||||
GUID guid = WSAID_WSARECVMSG;
|
||||
DWORD bytes = 0;
|
||||
if (WSAIoctl(sock, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid, sizeof(guid),
|
||||
&wsarecvmsg, sizeof(wsarecvmsg), &bytes, nullptr, nullptr) != 0) {
|
||||
// fallback to recv()
|
||||
received = ::recv(sock, buf, remaining, 0);
|
||||
if (received == SOCKET_ERROR)
|
||||
return ConvertReturnErrorCode(WSAGetLastError());
|
||||
} else {
|
||||
DWORD bytesReceived = 0;
|
||||
int res = wsarecvmsg(sock, reinterpret_cast<LPWSAMSG>(msg), &bytesReceived,
|
||||
nullptr, nullptr);
|
||||
if (res == SOCKET_ERROR)
|
||||
return ConvertReturnErrorCode(WSAGetLastError());
|
||||
received = bytesReceived;
|
||||
}
|
||||
} else {
|
||||
// fallback to recv() for UDP or single-buffer
|
||||
received = ::recv(sock, buf, remaining, 0);
|
||||
if (received == SOCKET_ERROR)
|
||||
return ConvertReturnErrorCode(WSAGetLastError());
|
||||
}
|
||||
|
||||
totalReceived += received;
|
||||
remaining -= received;
|
||||
buf += received;
|
||||
|
||||
// stop after first receive if WAITALL is not set
|
||||
if (!waitAll)
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
DWORD bytesReceived = 0;
|
||||
int res = wsarecvmsg(sock, reinterpret_cast<LPWSAMSG>(msg), &bytesReceived, nullptr, nullptr);
|
||||
return totalReceived;
|
||||
|
||||
if (res == SOCKET_ERROR) {
|
||||
return ConvertReturnErrorCode(-1);
|
||||
}
|
||||
return static_cast<int>(bytesReceived);
|
||||
#else
|
||||
int res = recvmsg(sock, reinterpret_cast<msghdr*>(msg), flags);
|
||||
int native_flags = convertOrbisFlagsToPosix(socket_type, flags);
|
||||
int res = recvmsg(sock, reinterpret_cast<msghdr*>(msg), native_flags);
|
||||
return ConvertReturnErrorCode(res);
|
||||
#endif
|
||||
}
|
||||
@ -251,15 +381,27 @@ int PosixSocket::ReceiveMessage(OrbisNetMsghdr* msg, int flags) {
|
||||
int PosixSocket::ReceivePacket(void* buf, u32 len, int flags, OrbisNetSockaddr* from,
|
||||
u32* fromlen) {
|
||||
std::scoped_lock lock{receive_mutex};
|
||||
if (from != nullptr) {
|
||||
sockaddr addr;
|
||||
int res = recvfrom(sock, (char*)buf, len, flags, &addr, (socklen_t*)fromlen);
|
||||
convertPosixSockaddrToOrbis(&addr, from);
|
||||
*fromlen = sizeof(OrbisNetSockaddrIn);
|
||||
return ConvertReturnErrorCode(res);
|
||||
} else {
|
||||
return ConvertReturnErrorCode(recv(sock, (char*)buf, len, flags));
|
||||
int res = 0;
|
||||
#ifdef _WIN32
|
||||
if (flags & ORBIS_NET_MSG_DONTWAIT) {
|
||||
res = socket_is_ready(sock);
|
||||
if (res <= 0)
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
const auto posix_flags = convertOrbisFlagsToPosix(socket_type, flags);
|
||||
if (from == nullptr) {
|
||||
res = recv(sock, (char*)buf, len, posix_flags);
|
||||
} else {
|
||||
sockaddr addr{};
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
res = recvfrom(sock, (char*)buf, len, posix_flags, &addr,
|
||||
(fromlen && *fromlen <= sizeof(addr) ? (socklen_t*)fromlen : &addrlen));
|
||||
if (res > 0)
|
||||
convertPosixSockaddrToOrbis(&addr, from);
|
||||
}
|
||||
|
||||
return ConvertReturnErrorCode(res);
|
||||
}
|
||||
|
||||
SocketPtr PosixSocket::Accept(OrbisNetSockaddr* addr, u32* addrlen) {
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project
|
||||
// SPDX-FileCopyrightText: Copyright 2024-2026 shadPS4 Emulator Project
|
||||
// SPDX-License-Identifier: GPL-2.0-or-later
|
||||
|
||||
#pragma once
|
||||
@ -62,7 +62,7 @@ struct OrbisNetLinger {
|
||||
s32 l_linger;
|
||||
};
|
||||
struct Socket {
|
||||
explicit Socket(int domain, int type, int protocol) {}
|
||||
explicit Socket(int domain, int type, int protocol) : socket_type(type) {}
|
||||
virtual ~Socket() = default;
|
||||
virtual bool IsValid() const = 0;
|
||||
virtual int Close() = 0;
|
||||
@ -84,6 +84,7 @@ struct Socket {
|
||||
virtual std::optional<net_socket> Native() = 0;
|
||||
std::mutex m_mutex;
|
||||
std::mutex receive_mutex;
|
||||
int socket_type;
|
||||
};
|
||||
|
||||
struct PosixSocket : public Socket {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user