| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418 |
- /*
- * Copyright (c) 2014, Oculus VR, Inc.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree. An additional grant
- * of patent rights can be found in the PATENTS file in the same directory.
- *
- */
- #include "NativeFeatureIncludes.h"
- #if _RAKNET_SUPPORT_MessageFilter==1
- #include "MessageFilter.h"
- #include "RakAssert.h"
- #include "GetTime.h"
- #include "MessageIdentifiers.h"
- #include "RakAssert.h"
- #include "RakPeerInterface.h"
- #include "PacketizedTCP.h"
- #include "BitStream.h"
- #ifdef _MSC_VER
- #pragma warning( push )
- #endif
- using namespace RakNet;
- /// Ordering function for C-string keys; returns a negative, zero or positive value exactly like strcmp().
- int RakNet::MessageFilterStrComp( char *const &key,char *const &data )
- {
- 	const int ordering = strcmp(key, data);
- 	return ordering;
- }
- /// Ordering function for the filter-set list: compares a filter set ID against data->filterSetID
- /// and returns -1 (key is smaller), 0 (equal) or 1 (key is larger).
- int RakNet::FilterSetComp( const int &key, FilterSet * const &data )
- {
- 	const int otherID = data->filterSetID;
- 	if (key == otherID)
- 		return 0;
- 	return key < otherID ? -1 : 1;
- }
- STATIC_FACTORY_DEFINITIONS(MessageFilter,MessageFilter);
- /// Constructs a filter with no filter sets and no tracked systems.
- /// New connections are not auto-assigned until SetAutoAddNewConnectionsToFilter() is called.
- MessageFilter::MessageFilter()
- {
- 	// Explicitly initialized so OnNewConnection() never reads an indeterminate value.
- 	// Any negative ID means "do not auto-add" (OnNewConnection only acts on IDs >= 0).
- 	autoAddNewConnectionsToFilter=-1;
- 	whenLastTimeoutCheck=RakNet::GetTime();
- }
- /// Frees every filter set and forgets every tracked system.
- MessageFilter::~MessageFilter()
- {
- 	Clear();
- }
- /// Puts every subsequently connecting system into filterSetID. A negative value disables auto-assignment.
- void MessageFilter::SetAutoAddNewConnectionsToFilter(int filterSetID)
- {
- 	this->autoAddNewConnectionsToFilter = filterSetID;
- }
- /// Allows (allow==true) or disallows every message ID in [messageIDStart, messageIDEnd] (inclusive)
- /// for a filter set, creating the filter set if it does not exist yet.
- /// IDs outside [0, MESSAGE_FILTER_MAX_MESSAGE_ID) are ignored instead of being written out of bounds.
- void MessageFilter::SetAllowMessageID(bool allow, int messageIDStart, int messageIDEnd,int filterSetID)
- {
- 	RakAssert(messageIDStart <= messageIDEnd);
- 	FilterSet *filterSet = GetFilterSetByID(filterSetID);
- 	// allowedIDs is cleared for MESSAGE_FILTER_MAX_MESSAGE_ID entries when the set is created, so
- 	// clamp the requested range to it; previously any out-of-range ID wrote past the array.
- 	const int first = messageIDStart < 0 ? 0 : messageIDStart;
- 	const int last = messageIDEnd >= MESSAGE_FILTER_MAX_MESSAGE_ID ? MESSAGE_FILTER_MAX_MESSAGE_ID - 1 : messageIDEnd;
- 	for (int i=first; i <= last; ++i)
- 		filterSet->allowedIDs[i]=allow;
- }
- /// Adds (allow==true) or removes (allow==false) an RPC4 function name in a filter set's allow-list,
- /// creating the filter set if needed. ID_RPC_PLUGIN is enabled when a name is newly added and
- /// disabled again once the allow-list becomes empty.
- void MessageFilter::SetAllowRPC4(bool allow, const char* uniqueID, int filterSetID)
- {
- 	FilterSet *filterSet = GetFilterSetByID(filterSetID);
- 	bool alreadyListed;
- 	const unsigned int slot = filterSet->allowedRPC4.GetIndexFromKey(uniqueID, &alreadyListed);
- 	if (allow && !alreadyListed)
- 	{
- 		filterSet->allowedRPC4.InsertAtIndex(uniqueID, slot, _FILE_AND_LINE_);
- 		filterSet->allowedIDs[ID_RPC_PLUGIN]=true;
- 	}
- 	else if (!allow && alreadyListed)
- 	{
- 		filterSet->allowedRPC4.RemoveAtIndex(slot);
- 		if (filterSet->allowedRPC4.Size()==0)
- 			filterSet->allowedIDs[ID_RPC_PLUGIN]=false;
- 	}
- }
- /// Configures the response when a system in filterSetID sends a disallowed message: optionally
- /// close its connection, and optionally ban its address with banTimeMS passed to AddToBanList().
- void MessageFilter::SetActionOnDisallowedMessage(bool kickOnDisallowed, bool banOnDisallowed, RakNet::TimeMS banTimeMS, int filterSetID)
- {
- 	FilterSet *const target = GetFilterSetByID(filterSetID);
- 	target->banOnDisallowedMessage=banOnDisallowed;
- 	target->disallowedMessageBanTimeMS=banTimeMS;
- 	target->kickOnDisallowedMessage=kickOnDisallowed;
- }
- /// Registers a callback, invoked with userData, whenever a system in filterSetID sends a
- /// disallowed message. Passing a null callback disables the notification.
- void MessageFilter::SetDisallowedMessageCallback(int filterSetID, void *userData, void (*invalidMessageCallback)(RakPeerInterface *peer, AddressOrGUID systemAddress, int filterSetID, void *userData, unsigned char messageID))
- {
- 	FilterSet *const target = GetFilterSetByID(filterSetID);
- 	target->disallowedCallbackUserData=userData;
- 	target->invalidMessageCallback=invalidMessageCallback;
- }
- /// Registers a callback, invoked with userData by Update() when a system's allowed time in
- /// filterSetID has expired (before it is kicked). Passing a null callback disables it.
- void MessageFilter::SetTimeoutCallback(int filterSetID, void *userData, void (*invalidMessageCallback)(RakPeerInterface *peer, AddressOrGUID systemAddress, int filterSetID, void *userData))
- {
- 	FilterSet *const target = GetFilterSetByID(filterSetID);
- 	target->timeoutUserData=userData;
- 	target->timeoutCallback=invalidMessageCallback;
- }
- /// Limits how long a system may stay in filterSetID: once allowedTimeMS has elapsed, Update() kicks it
- /// and, if banOnExceed is set, bans it for banTimeMS. allowedTimeMS <= 0 means no limit.
- void MessageFilter::SetFilterMaxTime(int allowedTimeMS, bool banOnExceed, RakNet::TimeMS banTimeMS, int filterSetID)
- {
- 	FilterSet *const target = GetFilterSetByID(filterSetID);
- 	target->timeExceedBanTimeMS=banTimeMS;
- 	target->banOnFilterTimeExceed=banOnExceed;
- 	target->maxMemberTimeMS=allowedTimeMS;
- }
- /// Returns the ID of the filter set the given system belongs to, or -1 if the system is not filtered.
- int MessageFilter::GetSystemFilterSet(AddressOrGUID systemAddress)
- {
- 	DataStructures::HashIndex slot = systemList.GetIndexOf(systemAddress);
- 	return slot.IsInvalid() ? -1 : systemList.ItemAtIndex(slot).filter->filterSetID;
- }
- /// Moves a system into filterSetID (creating that filter set if needed) and restarts its membership
- /// timer. A negative filterSetID stops filtering the system entirely.
- void MessageFilter::SetSystemFilterSet(AddressOrGUID addressOrGUID, int filterSetID)
- {
- 	RakAssert(addressOrGUID.IsUndefined()==false);
- 	DataStructures::HashIndex existing = systemList.GetIndexOf(addressOrGUID);
- 	const bool isTracked = !existing.IsInvalid();
- 
- 	if (filterSetID < 0)
- 	{
- 		// Negative ID: forget the system; nothing to do if it was never tracked.
- 		if (isTracked)
- 			systemList.RemoveAtIndex(existing, _FILE_AND_LINE_);
- 		return;
- 	}
- 
- 	if (isTracked)
- 	{
- 		// Already tracked: switch sets and restart the timer.
- 		FilterSet *filterSet = GetFilterSetByID(filterSetID);
- 		systemList.ItemAtIndex(existing).timeEnteredThisSet=RakNet::GetTimeMS();
- 		systemList.ItemAtIndex(existing).filter=filterSet;
- 		return;
- 	}
- 
- 	// Not tracked yet: add a new entry.
- 	FilteredSystem filteredSystem;
- 	filteredSystem.filter = GetFilterSetByID(filterSetID);
- 	filteredSystem.timeEnteredThisSet=RakNet::GetTimeMS();
- 	systemList.Push(addressOrGUID,filteredSystem,_FILE_AND_LINE_);
- }
- /// Counts filtered systems: all of them when filterSetID==-1, otherwise only those in filterSetID.
- unsigned MessageFilter::GetSystemCount(int filterSetID) const
- {
- 	if (filterSetID==-1)
- 		return systemList.Size();
- 
- 	DataStructures::List< FilteredSystem > members;
- 	DataStructures::List< AddressOrGUID > addresses;
- 	systemList.GetAsList(members, addresses, _FILE_AND_LINE_);
- 	unsigned matches=0;
- 	for (unsigned k=0; k < members.Size(); ++k)
- 	{
- 		if (members[k].filter->filterSetID==filterSetID)
- 			matches++;
- 	}
- 	return matches;
- }
- /// Returns how many filter sets currently exist.
- unsigned MessageFilter::GetFilterSetCount(void) const
- {
- 	const unsigned setCount = filterList.Size();
- 	return setCount;
- }
- /// Returns the ID of the filter set stored at position index; index must be below GetFilterSetCount().
- int MessageFilter::GetFilterSetIDByIndex(unsigned index)
- {
- 	FilterSet *const filterSet = filterList[index];
- 	return filterSet->filterSetID;
- }
- /// Deletes filter set filterSetID, if it exists. Every system that was in it is removed from the
- /// system list, so those systems become unfiltered. Does nothing when the set does not exist.
- void MessageFilter::DeleteFilterSet(int filterSetID)
- {
- 	bool objectExists;
- 	const unsigned index = filterList.GetIndexFromKey(filterSetID, &objectExists);
- 	if (objectExists==false)
- 		return;
- 
- 	FilterSet *filterSet=filterList[index];
- 	filterList.RemoveAtIndex(index);
- 
- 	// Detach every system that references this set BEFORE freeing it. Previously the set was
- 	// deleted first and its (now invalid) address was then compared against each system's pointer.
- 	DataStructures::List< FilteredSystem > itemList;
- 	DataStructures::List< AddressOrGUID > keyList;
- 	systemList.GetAsList(itemList, keyList, _FILE_AND_LINE_);
- 	for (unsigned i=0; i < itemList.Size(); i++)
- 	{
- 		if (itemList[i].filter==filterSet)
- 			systemList.Remove(keyList[i], _FILE_AND_LINE_);
- 	}
- 
- 	DeallocateFilterSet(filterSet);
- }
- /// Forgets every tracked system, then frees and removes every filter set.
- void MessageFilter::Clear(void)
- {
- 	systemList.Clear(_FILE_AND_LINE_);
- 	const unsigned setCount = filterList.Size();
- 	for (unsigned j=0; j < setCount; ++j)
- 		DeallocateFilterSet(filterList[j]);
- 	filterList.Clear(false, _FILE_AND_LINE_);
- }
- /// Releases a filter set allocated by GetFilterSetByID(). Does not touch filterList or systemList.
- void MessageFilter::DeallocateFilterSet(FilterSet* filterSet)
- {
- 	RakNet::OP_DELETE<FilterSet>(filterSet, _FILE_AND_LINE_);
- }
- /// Returns the filter set with the given ID, creating and registering a new one if none exists.
- /// A newly created set disallows every message ID, takes no kick/ban action, has no time limit
- /// (maxMemberTimeMS 0) and has no callbacks.
- FilterSet* MessageFilter::GetFilterSetByID(int filterSetID)
- {
- 	RakAssert(filterSetID>=0);
- 	bool objectExists;
- 	const unsigned index = filterList.GetIndexFromKey(filterSetID, &objectExists);
- 	if (objectExists)
- 		return filterList[index];
- 
- 	FilterSet *newFilterSet = RakNet::OP_NEW<FilterSet>( _FILE_AND_LINE_ );
- 	memset(newFilterSet->allowedIDs, 0, MESSAGE_FILTER_MAX_MESSAGE_ID * sizeof(bool));
- 	newFilterSet->banOnFilterTimeExceed=false;
- 	newFilterSet->kickOnDisallowedMessage=false;
- 	newFilterSet->banOnDisallowedMessage=false;
- 	newFilterSet->disallowedMessageBanTimeMS=0;
- 	newFilterSet->timeExceedBanTimeMS=0;
- 	newFilterSet->maxMemberTimeMS=0;
- 	newFilterSet->filterSetID=filterSetID;
- 	newFilterSet->invalidMessageCallback=0;
- 	// Was left uninitialized, unlike every other callback / user-data field.
- 	newFilterSet->disallowedCallbackUserData=0;
- 	newFilterSet->timeoutCallback=0;
- 	newFilterSet->timeoutUserData=0;
- 	filterList.Insert(filterSetID, newFilterSet, true, _FILE_AND_LINE_);
- 	return newFilterSet;
- }
- /// Applies a filter set's response to a disallowed message sent by systemAddress: first invokes the
- /// set's invalidMessageCallback (if any), then optionally bans the sender's address, then optionally
- /// closes its connection.
- /// \param[in] filterSet The filter set the sender belongs to.
- /// \param[in] systemAddress The sender of the disallowed message.
- /// \param[in] messageID Forwarded unchanged to invalidMessageCallback.
- void MessageFilter::OnInvalidMessage(FilterSet *filterSet, AddressOrGUID systemAddress, unsigned char messageID)
- {
- // The callback runs even when rakPeerInterface is null (it is passed through as-is).
- if (filterSet->invalidMessageCallback)
- filterSet->invalidMessageCallback(rakPeerInterface, systemAddress, filterSet->filterSetID, filterSet->disallowedCallbackUserData, messageID);
- // Banning is only possible through RakPeerInterface; the ban string is built from the
- // systemAddress member (ToString's first argument false: presumably without the port).
- if (filterSet->banOnDisallowedMessage && rakPeerInterface)
- {
- char str1[64];
- systemAddress.systemAddress.ToString(false, str1);
- rakPeerInterface->AddToBanList(str1, filterSet->disallowedMessageBanTimeMS);
- }
- if (filterSet->kickOnDisallowedMessage)
- {
- // Close via RakPeerInterface when attached to one; otherwise, only when PacketizedTCP and
- // TCPInterface support are compiled in, via tcpInterface (assumed non-null in that case).
- if (rakPeerInterface)
- rakPeerInterface->CloseConnection(systemAddress, true, 0);
- #if _RAKNET_SUPPORT_PacketizedTCP==1 && _RAKNET_SUPPORT_TCPInterface==1
- else
- tcpInterface->CloseConnection(systemAddress.systemAddress);
- #endif
- }
- }
- /// Periodic tick. About once per second, every system that has been in a time-limited filter set
- /// (maxMemberTimeMS > 0) for at least maxMemberTimeMS triggers the set's timeout callback, is
- /// optionally banned, has its connection closed, and is removed from the system list.
- void MessageFilter::Update(void)
- {
- 	RakNet::Time curTime = RakNet::GetTime();
- 	// GreaterThan() rather than '>' — presumably so timer wraparound is tolerated.
- 	if (GreaterThan(curTime - 1000, whenLastTimeoutCheck)==false)
- 		return;
- 	// Record when this check ran. Storing curTime+1000 here (as before), combined with the
- 	// "curTime - 1000" test above, only ran the check every two seconds.
- 	whenLastTimeoutCheck=curTime;
- 
- 	// Iterate over a snapshot so entries can be removed from systemList inside the loop.
- 	DataStructures::List< FilteredSystem > itemList;
- 	DataStructures::List< AddressOrGUID > keyList;
- 	systemList.GetAsList(itemList, keyList, _FILE_AND_LINE_);
- 	for (unsigned int index=0; index < itemList.Size(); index++)
- 	{
- 		FilterSet *filter = itemList[index].filter;
- 		if (filter==0 || filter->maxMemberTimeMS<=0)
- 			continue; // Not in a time-limited set.
- 		if (curTime-itemList[index].timeEnteredThisSet < filter->maxMemberTimeMS)
- 			continue; // Time not used up yet.
- 
- 		if (filter->timeoutCallback)
- 			filter->timeoutCallback(rakPeerInterface, keyList[index], filter->filterSetID, filter->timeoutUserData);
- 		if (filter->banOnFilterTimeExceed && rakPeerInterface)
- 		{
- 			// Ban by network address, as OnInvalidMessage() does. AddressOrGUID::ToString() may
- 			// format the GUID instead, which would not match the peer's address in the ban list.
- 			char str1[64];
- 			keyList[index].systemAddress.ToString(false, str1);
- 			rakPeerInterface->AddToBanList(str1, filter->timeExceedBanTimeMS);
- 		}
- 		if (rakPeerInterface)
- 			rakPeerInterface->CloseConnection(keyList[index], true, 0);
- #if _RAKNET_SUPPORT_PacketizedTCP==1 && _RAKNET_SUPPORT_TCPInterface==1
- 		else
- 			tcpInterface->CloseConnection(keyList[index].systemAddress);
- #endif
- 		systemList.Remove(keyList[index], _FILE_AND_LINE_);
- 	}
- }
- /// Plugin callback for a new connection. When auto-assignment is enabled (ID >= 0) and the system is
- /// not already tracked, puts it into the auto-add filter set.
- void MessageFilter::OnNewConnection(const SystemAddress &systemAddress, RakNetGUID rakNetGUID, bool isIncoming)
- {
- 	(void) isIncoming;
- 	if (autoAddNewConnectionsToFilter < 0)
- 		return;
- 	AddressOrGUID newSystem;
- 	newSystem.rakNetGuid=rakNetGUID;
- 	newSystem.systemAddress=systemAddress;
- 	if (systemList.HasData(newSystem)==false)
- 		SetSystemFilterSet(newSystem, autoAddNewConnectionsToFilter);
- }
- /// Plugin callback for a closed or lost connection: stops tracking that system.
- void MessageFilter::OnClosedConnection(const SystemAddress &systemAddress, RakNetGUID rakNetGUID, PI2_LostConnectionReason lostConnectionReason )
- {
- 	(void) lostConnectionReason;
- 	AddressOrGUID departed;
- 	departed.systemAddress=systemAddress;
- 	departed.rakNetGuid=rakNetGUID;
- 	systemList.Remove(departed, _FILE_AND_LINE_);
- }
- /// Plugin callback for every incoming packet. Connection-management notifications always pass. Any
- /// other message from a system that belongs to a filter set must have its message ID (the byte after
- /// the timestamp for ID_TIMESTAMP packets) allowed by that set, and ID_RPC_PLUGIN messages must also
- /// name an allowed RPC4 function. Otherwise the set's disallowed-message action runs and the packet
- /// is dropped.
- PluginReceiveResult MessageFilter::OnReceive(Packet *packet)
- {
- 	switch (packet->data[0])
- 	{
- 	case ID_NEW_INCOMING_CONNECTION:
- 	case ID_CONNECTION_REQUEST_ACCEPTED:
- 	case ID_CONNECTION_LOST:
- 	case ID_DISCONNECTION_NOTIFICATION:
- 	case ID_CONNECTION_ATTEMPT_FAILED:
- 	case ID_NO_FREE_INCOMING_CONNECTIONS:
- 	case ID_IP_RECENTLY_CONNECTED:
- 	case ID_CONNECTION_BANNED:
- 	case ID_INVALID_PASSWORD:
- 	case ID_UNCONNECTED_PONG:
- 	case ID_ALREADY_CONNECTED:
- 	case ID_ADVERTISE_SYSTEM:
- 	case ID_REMOTE_DISCONNECTION_NOTIFICATION:
- 	case ID_REMOTE_CONNECTION_LOST:
- 	case ID_REMOTE_NEW_INCOMING_CONNECTION:
- 	case ID_DOWNLOAD_PROGRESS:
- 		// Connection management / system notifications are never filtered.
- 		return RR_CONTINUE_PROCESSING;
- 	default:
- 		break;
- 	}
- 
- 	// Locate the real message ID. A timestamped packet is ID_TIMESTAMP, then a RakNet::Time (the type
- 	// GetTime() returns), then the actual message ID. Sizing the timestamp as TimeMS read the wrong
- 	// byte whenever RakNet::Time is 64-bit.
- 	unsigned int messageIdOffset = 0;
- 	if (packet->data[0]==ID_TIMESTAMP)
- 	{
- 		messageIdOffset = sizeof(MessageID) + sizeof(RakNet::Time);
- 		// At least one byte must follow the timestamp; the old '<' test allowed reading data[length].
- 		if (packet->length <= messageIdOffset)
- 			return RR_STOP_PROCESSING_AND_DEALLOCATE; // Invalid message
- 	}
- 	const unsigned char messageId = packet->data[messageIdOffset];
- 
- 	// Systems that are not in any filter set are not filtered.
- 	DataStructures::HashIndex index = systemList.GetIndexOf(packet);
- 	if (index.IsInvalid())
- 		return RR_CONTINUE_PROCESSING;
- 	FilterSet *filter = systemList.ItemAtIndex(index).filter;
- 
- 	if (filter->allowedIDs[messageId]==false)
- 	{
- 		// Report the ID that was actually rejected, not ID_TIMESTAMP.
- 		OnInvalidMessage(filter, packet, messageId);
- 		return RR_STOP_PROCESSING_AND_DEALLOCATE;
- 	}
- 	if (messageId==ID_RPC_PLUGIN)
- 	{
- 		// Also verify the RPC4 function name. Timestamped RPC packets used to skip this check.
- 		// Layout assumed (as before): ID_RPC_PLUGIN, one more byte, then the compressed function name.
- 		RakNet::BitStream bsIn(packet->data,packet->length,false);
- 		bsIn.IgnoreBytes(messageIdOffset + 2);
- 		RakNet::RakString functionName;
- 		bsIn.ReadCompressed(functionName);
- 		if (filter->allowedRPC4.HasData(functionName)==false)
- 		{
- 			OnInvalidMessage(filter, packet, messageId);
- 			return RR_STOP_PROCESSING_AND_DEALLOCATE;
- 		}
- 	}
- 	return RR_CONTINUE_PROCESSING;
- }
- #ifdef _MSC_VER
- #pragma warning( pop )
- #endif
- #endif // _RAKNET_SUPPORT_*
|