-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathUnauthenticatedSessionTable.h
203 lines (175 loc) · 6.72 KB
/
UnauthenticatedSessionTable.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
/*
*
* Copyright (c) 2020 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
#include <limits>

#include <core/CHIPError.h>
#include <core/ReferenceCounted.h>
#include <support/CodeUtils.h>
#include <support/Pool.h>
#include <support/ReferenceCountedHandle.h>
#include <support/logging/CHIPLogging.h>
#include <system/TimeSource.h>
#include <transport/raw/PeerAddress.h>
namespace chip {
namespace Transport {
// Forward declaration so the handle alias and the deleter below can name the type before it is defined.
class UnauthenticatedSession;
// Handle type used to hold on to an UnauthenticatedSession entry.
using UnauthenticatedSessionHandle = ReferenceCountedHandle<UnauthenticatedSession>;
/// Deleter policy handed to the ReferenceCounted base of UnauthenticatedSession.
class UnauthenticatedSessionDeleter
{
public:
    /**
     * Intentionally does nothing: entries are not freed here, they are recycled by the LRU rotation of
     * UnauthenticatedSessionTable.
     *
     * @param entry the session being released (unused)
     */
    static void Release(UnauthenticatedSession * entry)
    {
        (void) entry;
    }
};
/**
 * @brief
 *   An UnauthenticatedSession stores the binding of a transport (peer) address and its message counters.
 *
 * Entries are rotated using LRU. An entry can be held through an UnauthenticatedSessionHandle, which increases the
 * reference count by 1; while the reference count is not 0, the entry won't be pruned.
 */
class UnauthenticatedSession : public ReferenceCounted<UnauthenticatedSession, UnauthenticatedSessionDeleter>
{
public:
    /// Creates a session bound to @p address. Explicit so that a PeerAddress is never silently converted to a session.
    explicit UnauthenticatedSession(const PeerAddress & address) : mPeerAddress(address) {}

    // Sessions are referred to by pointer/handle, so they are neither copyable nor movable.
    UnauthenticatedSession(const UnauthenticatedSession &) = delete;
    UnauthenticatedSession & operator=(const UnauthenticatedSession &) = delete;
    UnauthenticatedSession(UnauthenticatedSession &&)                  = delete;
    UnauthenticatedSession & operator=(UnauthenticatedSession &&) = delete;

    /// @return the last recorded activity timestamp in milliseconds (0 until one has been set).
    uint64_t GetLastActivityTimeMs() const { return mLastActivityTimeMs; }

    /// Records @p value (milliseconds) as the last activity timestamp.
    void SetLastActivityTimeMs(uint64_t value) { mLastActivityTimeMs = value; }

    /// @return the peer address this session was created for; it never changes for the lifetime of the object.
    const PeerAddress & GetPeerAddress() const { return mPeerAddress; }

    /// @return the counter used for messages sent by this node on the session.
    MessageCounter & GetLocalMessageCounter() { return mLocalMessageCounter; }

    /// @return the counter tracking messages received from the peer.
    PeerMessageCounter & GetPeerMessageCounter() { return mPeerMessageCounter; }

private:
    uint64_t mLastActivityTimeMs = 0;
    const PeerAddress mPeerAddress;
    GlobalUnencryptedMessageCounter mLocalMessageCounter;
    PeerMessageCounter mPeerMessageCounter;
};
/**
 * @brief
 *   A fixed-capacity table of UnauthenticatedSession entries, looked up by peer address.
 *
 * When no free slot is left, the least recently used entry that nobody holds a reference to is recycled, provided
 * it has been inactive for at least kMinimalActivityTimeMs.
 *
 * @tparam kMaxConnectionCount  number of entries the internal pool can hold
 * @tparam kTimeSource          time source used to timestamp session activity (defaults to the system clock)
 */
template <size_t kMaxConnectionCount, Time::Source kTimeSource = Time::Source::kSystem>
class UnauthenticatedSessionTable
{
public:
    /**
     * Allocates a new session out of the internal resource pool.
     *
     * If the pool cannot provide a new object, the least recently used unreferenced entry is reset and reused,
     * unless it was active less than kMinimalActivityTimeMs ago.
     *
     * @param address  peer address the new session is bound to
     * @param entry    [out] the allocated session on success, nullptr on failure
     *
     * @returns CHIP_NO_ERROR if new session created. May fail if maximum connection count has been reached (with
     *          CHIP_ERROR_NO_MEMORY).
     */
    CHECK_RETURN_VALUE
    CHIP_ERROR AllocEntry(const PeerAddress & address, UnauthenticatedSession *& entry)
    {
        entry = mEntries.CreateObject(address);
        if (entry != nullptr)
        {
            return CHIP_NO_ERROR;
        }

        // No new object could be created: fall back to recycling the least recently used unreferenced entry.
        entry = FindLeastRecentUsedEntry();
        if (entry == nullptr)
        {
            return CHIP_ERROR_NO_MEMORY;
        }

        const uint64_t currentTime = mTimeSource.GetCurrentMonotonicTimeMs();
        if (currentTime - entry->GetLastActivityTimeMs() < kMinimalActivityTimeMs)
        {
            // Protect the entry for a short period to prevent entries from rotating too fast.
            entry = nullptr;
            return CHIP_ERROR_NO_MEMORY;
        }

        mEntries.ResetObject(entry, address);
        return CHIP_NO_ERROR;
    }

    /**
     * Get a session using given address
     *
     * @return the peer found, nullptr if not found
     */
    CHECK_RETURN_VALUE
    UnauthenticatedSession * FindEntry(const PeerAddress & address)
    {
        UnauthenticatedSession * result = nullptr;
        mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
            if (MatchPeerAddress(entry->GetPeerAddress(), address))
            {
                result = entry;
                return false; // first match wins, stop iterating
            }
            return true;
        });
        return result;
    }

    /**
     * Get a session given the peer address. If no session exists for that address, allocate a new entry for it.
     *
     * @return the session found or allocated, nullptr if not found and allocation failed.
     */
    CHECK_RETURN_VALUE
    UnauthenticatedSession * FindOrAllocateEntry(const PeerAddress & address)
    {
        UnauthenticatedSession * result = FindEntry(address);
        if (result != nullptr)
        {
            return result;
        }

        const CHIP_ERROR err = AllocEntry(address, result);
        return (err == CHIP_NO_ERROR) ? result : nullptr;
    }

    /// Mark a session as active by stamping it with the current monotonic time of the table's time source.
    void MarkSessionActive(UnauthenticatedSession & entry) { entry.SetLastActivityTimeMs(mTimeSource.GetCurrentMonotonicTimeMs()); }

    /// Allows access to the underlying time source used for keeping track of connection active time
    Time::TimeSource<kTimeSource> & GetTimeSource() { return mTimeSource; }

private:
    /**
     * Scans all active entries for the eviction candidate.
     *
     * @return the entry with the smallest last-activity timestamp among those whose reference count is 0,
     *         or nullptr if every active entry is still referenced.
     */
    UnauthenticatedSession * FindLeastRecentUsedEntry()
    {
        // NOTE(review): this relies on idle entries reporting a reference count of 0 -- confirm the initial count
        // used by the ReferenceCounted base.
        UnauthenticatedSession * result = nullptr;
        uint64_t oldestTimeMs           = std::numeric_limits<uint64_t>::max();

        mEntries.ForEachActiveObject([&](UnauthenticatedSession * entry) {
            if (entry->GetReferenceCount() == 0 && entry->GetLastActivityTimeMs() < oldestTimeMs)
            {
                result       = entry;
                oldestTimeMs = entry->GetLastActivityTimeMs();
            }
            return true; // always visit every entry
        });

        return result;
    }

    /**
     * Decides whether two peer addresses designate the same peer.
     *
     * Addresses of different transport types never match, and undefined addresses match nothing. UDP/TCP
     * addresses match on IP address and port. BLE addresses currently always match each other.
     */
    static bool MatchPeerAddress(const PeerAddress & a1, const PeerAddress & a2)
    {
        if (a1.GetTransportType() != a2.GetTransportType())
        {
            return false;
        }

        switch (a1.GetTransportType())
        {
        case Transport::Type::kUndefined:
            return false;
        case Transport::Type::kUdp:
        case Transport::Type::kTcp:
            // Ignore interface in the address
            return a1.GetIPAddress() == a2.GetIPAddress() && a1.GetPort() == a2.GetPort();
        case Transport::Type::kBle:
            // TODO: complete BLE address comparison
            return true;
        }

        return false;
    }

    /// Minimal idle time (ms) before an unreferenced entry may be recycled for another peer.
    static constexpr uint64_t kMinimalActivityTimeMs = 30000;

    // Must be parameterized on kTimeSource: GetTimeSource() hands out a Time::TimeSource<kTimeSource> reference.
    Time::TimeSource<kTimeSource> mTimeSource;
    BitMapObjectPool<UnauthenticatedSession, kMaxConnectionCount> mEntries;
};
} // namespace Transport
} // namespace chip