236 lines
5.3 KiB
C++
236 lines
5.3 KiB
C++
#include "stdafx.h"
|
|
#include "KOSocket.h"
|
|
#include "packets.h"
|
|
#include "version.h"
|
|
|
|
KOSocket::KOSocket(uint16 socketID, SocketMgr * mgr, SOCKET fd, uint32 sendBufferSize, uint32 recvBufferSize)
|
|
: Socket(fd, sendBufferSize, recvBufferSize),
|
|
m_socketID(socketID), m_remaining(0), m_usingCrypto(false),
|
|
m_readTries(0), m_sequence(0), m_lastResponse(0)
|
|
{
|
|
SetSocketMgr(mgr);
|
|
}
|
|
|
|
void KOSocket::OnConnect()
|
|
{
|
|
if (GetRemoteIP() == "127.0.0.1")
|
|
TRACE("Connection received from %s:%d\n", GetRemoteIP().c_str(), GetRemotePort());
|
|
|
|
m_remaining = 0;
|
|
m_usingCrypto = false;
|
|
m_readTries = 0;
|
|
m_sequence = 0;
|
|
m_lastResponse = UNIXTIME;
|
|
}
|
|
|
|
void KOSocket::OnRead()
|
|
{
|
|
Packet pkt;
|
|
|
|
for (;;)
|
|
{
|
|
if (m_remaining == 0)
|
|
{
|
|
if (GetReadBuffer().GetSize() < 5)
|
|
return; //check for opcode as well
|
|
|
|
uint16 header = 0;
|
|
GetReadBuffer().Read(&header, 2);
|
|
if (header != 0x55AA)
|
|
{
|
|
TRACE("%s: Got packet without header 0x55AA, got 0x%X\n", GetRemoteIP().c_str(), header);
|
|
goto error_handler;
|
|
}
|
|
|
|
GetReadBuffer().Read(&m_remaining, 2);
|
|
if (m_remaining == 0)
|
|
{
|
|
TRACE("%s: Got packet without an opcode, this should never happen.\n", GetRemoteIP().c_str());
|
|
goto error_handler;
|
|
}
|
|
}
|
|
|
|
if (m_remaining > GetReadBuffer().GetAllocatedSize())
|
|
{
|
|
TRACE("%s: Packet received which was %u bytes in size, maximum of %u.\n", GetRemoteIP().c_str(), m_remaining, GetReadBuffer().GetAllocatedSize());
|
|
goto error_handler;
|
|
}
|
|
|
|
if (m_remaining > GetReadBuffer().GetSize())
|
|
{
|
|
if (m_readTries > 4)
|
|
{
|
|
TRACE("%s: packet fragmentation count is over 4, disconnecting as they're probably up to something bad\n", GetRemoteIP().c_str());
|
|
goto error_handler;
|
|
}
|
|
m_readTries++;
|
|
return;
|
|
}
|
|
|
|
uint8 *in_stream = new uint8[m_remaining];
|
|
|
|
m_readTries = 0;
|
|
GetReadBuffer().Read(in_stream, m_remaining);
|
|
|
|
uint16 footer = 0;
|
|
GetReadBuffer().Read(&footer, 2);
|
|
|
|
if (footer != 0xAA55 || !DecryptPacket(in_stream, pkt))
|
|
{
|
|
TRACE("%s: Footer invalid (%X) or failed to decrypt.\n", GetRemoteIP().c_str(), footer);
|
|
delete [] in_stream;
|
|
goto error_handler;
|
|
}
|
|
|
|
delete [] in_stream;
|
|
|
|
m_lastResponse = UNIXTIME;
|
|
|
|
if (!HandlePacket(pkt))
|
|
{
|
|
TRACE("%s: Handler for packet %X returned false\n", GetRemoteIP().c_str(), pkt.GetOpcode());
|
|
#ifndef _DEBUG
|
|
goto error_handler;
|
|
#endif
|
|
}
|
|
|
|
// Update the time of the last (valid) response from the client.
|
|
|
|
m_remaining = 0;
|
|
}
|
|
|
|
return;
|
|
|
|
error_handler:
|
|
Disconnect();
|
|
}
|
|
|
|
bool KOSocket::DecryptPacket(uint8 *in_stream, Packet & pkt)
|
|
{
|
|
uint8* final_packet = nullptr;
|
|
|
|
if (isCryptoEnabled())
|
|
{
|
|
// Invalid packet (all encrypted packets need a CRC32 checksum!)
|
|
if (m_remaining < 4
|
|
// Invalid checksum
|
|
|| m_crypto.JvDecryptionWithCRC32(m_remaining, in_stream, in_stream) < 0
|
|
// Invalid sequence ID
|
|
|| ++m_sequence != *(uint32 *)(in_stream))
|
|
return false;
|
|
|
|
m_remaining -= 8; // remove the sequence ID & CRC checksum
|
|
final_packet = &in_stream[4];
|
|
}
|
|
else
|
|
{
|
|
final_packet = in_stream; // for simplicity :P
|
|
}
|
|
|
|
m_remaining--;
|
|
pkt = Packet(final_packet[0], (size_t)m_remaining);
|
|
if (m_remaining > 0)
|
|
{
|
|
pkt.resize(m_remaining);
|
|
memcpy((void*)pkt.contents(), &final_packet[1], m_remaining);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool KOSocket::Send(Packet * pkt)
|
|
{
|
|
if (!IsConnected() || pkt->size() + 1 > GetWriteBuffer().GetAllocatedSize())
|
|
return false;
|
|
|
|
bool r;
|
|
|
|
uint8 opcode = pkt->GetOpcode();
|
|
uint8 * out_stream = nullptr;
|
|
uint16 len = (uint16)(pkt->size() + 1);
|
|
|
|
if (isCryptoEnabled())
|
|
{
|
|
len += 5;
|
|
|
|
out_stream = new uint8[len];
|
|
|
|
*(uint16 *)&out_stream[0] = 0x1efc;
|
|
*(uint16 *)&out_stream[2] = (uint16)(m_sequence); // this isn't actually incremented here
|
|
out_stream[4] = 0;
|
|
out_stream[5] = pkt->GetOpcode();
|
|
|
|
if (pkt->size() > 0)
|
|
memcpy(&out_stream[6], pkt->contents(), pkt->size());
|
|
|
|
m_crypto.JvEncryptionFast(len, out_stream, out_stream);
|
|
}
|
|
else
|
|
{
|
|
out_stream = new uint8[len];
|
|
out_stream[0] = pkt->GetOpcode();
|
|
if (pkt->size() > 0)
|
|
memcpy(&out_stream[1], pkt->contents(), pkt->size());
|
|
}
|
|
|
|
|
|
BurstBegin();
|
|
if (GetWriteBuffer().GetSpace() < size_t(len + 6))
|
|
{
|
|
TRACE("Disonnected due to insufficient buffer space [%d]. \n", GetSocketID());
|
|
BurstEnd();
|
|
Disconnect();
|
|
return false;
|
|
}
|
|
r = BurstSend((const uint8*)"\xaa\x55", 2);
|
|
if (r) r = BurstSend((const uint8*)&len, 2);
|
|
if (r) r = BurstSend((const uint8*)out_stream, len);
|
|
if (r) r = BurstSend((const uint8*)"\x55\xaa", 2);
|
|
if (r) BurstPush();
|
|
BurstEnd();
|
|
|
|
delete [] out_stream;
|
|
return r;
|
|
}
|
|
|
|
bool KOSocket::SendCompressed(Packet * pkt)
|
|
{
|
|
if (pkt->size() < 500)
|
|
return Send(pkt);
|
|
|
|
Packet result(WIZ_COMPRESS_PACKET);
|
|
uint32 inLength = pkt->size() + 1, outLength = inLength + LZF_MARGIN, crc;
|
|
uint8 *buffer = new uint8[inLength], *outBuffer = new uint8[outLength];
|
|
|
|
*buffer = pkt->GetOpcode();
|
|
if (pkt->size() > 0)
|
|
memcpy(buffer + 1, pkt->contents(), pkt->size());
|
|
|
|
crc = (uint32)crc32(buffer, inLength, 0);
|
|
outLength = lzf_compress(buffer, inLength, outBuffer, outLength);
|
|
|
|
result << outLength << inLength;
|
|
result << uint32(crc);
|
|
|
|
result.append(outBuffer, outLength);
|
|
|
|
delete [] buffer;
|
|
delete [] outBuffer;
|
|
|
|
return Send(&result);
|
|
}
|
|
|
|
void KOSocket::OnDisconnect()
|
|
{
|
|
if (GetRemoteIP() == "127.0.0.1")
|
|
printf("Connection closed from %s:%d\n", GetRemoteIP().c_str(), GetRemotePort());
|
|
}
|
|
|
|
void KOSocket::EnableCrypto()
|
|
{
|
|
#ifdef USE_CRYPTION
|
|
m_crypto.Init();
|
|
m_usingCrypto = true;
|
|
#endif
|
|
}
|