-
Notifications
You must be signed in to change notification settings - Fork 44
/
Copy pathLibDiamond.sol
262 lines (242 loc) · 12.6 KB
/
LibDiamond.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
// SPDX-License-Identifier: MIT
// The file-level custom error declared in this file requires compiler 0.8.4 or later,
// so the version floor is 0.8.4 rather than 0.8.0.
pragma solidity ^0.8.4;
/******************************************************************************\
* Author: Nick Mudge <[email protected]> (https://twitter.com/mudgen)
* EIP-2535 Diamonds: https://eips.ethereum.org/EIPS/eip-2535
/******************************************************************************/
import { IDiamondCut } from "../interfaces/IDiamondCut.sol";
// Remember to add the loupe functions from DiamondLoupeFacet to the diamond.
// The loupe functions are required by the EIP2535 Diamonds standard
// Raised by LibDiamond.initializeDiamondCut when the delegatecall to the
// initialization contract fails and returns no revert data.
// _initializationContractAddress: the address that was delegatecalled.
// _calldata: the calldata that was passed to that delegatecall.
error InitializationFunctionReverted(address _initializationContractAddress, bytes _calldata);
/// @title LibDiamond
/// @dev Diamond storage layout, contract-owner helpers and the internal diamondCut
///      implementation. Function selectors are packed eight per 32-byte slot in
///      `selectorSlots`; each entry in `facets` packs the facet address (high 20 bytes)
///      together with the selector's position in `selectorSlots` (low bytes).
library LibDiamond {
    // Storage slot at which the DiamondStorage struct is anchored.
    bytes32 constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

    struct DiamondStorage {
        // maps function selectors to the facets that execute the functions.
        // and maps the selectors to their position in the selectorSlots array.
        // func selector => address facet, selector position
        mapping(bytes4 => bytes32) facets;
        // array of slots of function selectors.
        // each slot holds 8 function selectors.
        mapping(uint256 => bytes32) selectorSlots;
        // The number of function selectors in selectorSlots
        uint16 selectorCount;
        // Used to query if a contract implements an interface.
        // Used to implement ERC-165.
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @dev Returns a storage reference to the DiamondStorage struct.
    /// @return ds Storage reference whose slot is set to DIAMOND_STORAGE_POSITION.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        assembly {
            ds.slot := position
        }
    }

    /// @dev Emitted by setContractOwner every time it is called.
    /// @param previousOwner The owner stored before the call.
    /// @param newOwner The owner stored by the call.
    event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);

    /// @dev Stores `_newOwner` as the contract owner and emits OwnershipTransferred.
    ///      Performs no access control itself; callers must check authorization.
    /// @param _newOwner The address to store as contract owner.
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @dev Reads the contract owner from diamond storage.
    /// @return contractOwner_ The currently stored owner address.
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @dev Reverts unless msg.sender is the stored contract owner.
    function enforceIsContractOwner() internal view {
        require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner");
    }

    /// @dev Emitted by diamondCut after the selector changes are applied and before
    ///      the initialization call is made.
    /// @param _diamondCut The facet cuts that were applied.
    /// @param _init The initialization contract address passed to diamondCut.
    /// @param _calldata The initialization calldata passed to diamondCut.
    event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);

    // Low 12 bytes set: AND-ing a `facets` entry with this mask clears the facet
    // address (high 20 bytes) and keeps the selector position.
    bytes32 constant CLEAR_ADDRESS_MASK = bytes32(uint256(0xffffffffffffffffffffffff));
    // High 4 bytes set: shifted right to a selector's position and inverted, it
    // clears that selector inside a selector slot.
    bytes32 constant CLEAR_SELECTOR_MASK = bytes32(uint256(0xffffffff << 224));

    /// @dev Internal function version of diamondCut.
    ///      This code is almost the same as the external diamondCut,
    ///      except it is using 'FacetCut[] memory _diamondCut' instead of
    ///      'FacetCut[] calldata _diamondCut'.
    ///      The code is duplicated to prevent copying calldata to memory which
    ///      causes an error for a two dimensional array.
    ///      Reverts if the resulting selector count does not fit in the uint16
    ///      `selectorCount` storage field.
    /// @param _diamondCut Facet cuts to apply, in order.
    /// @param _init Contract to delegatecall after the cut; address(0) skips the call.
    /// @param _calldata Calldata for the delegatecall to `_init`.
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        DiamondStorage storage ds = diamondStorage();
        uint256 originalSelectorCount = ds.selectorCount;
        uint256 selectorCount = originalSelectorCount;
        bytes32 selectorSlot;
        // Check if last selector slot is not full
        // "selectorCount & 7" is a gas efficient modulo by eight "selectorCount % 8"
        if (selectorCount & 7 > 0) {
            // get last selectorSlot
            // "selectorCount >> 3" is a gas efficient division by 8 "selectorCount / 8"
            selectorSlot = ds.selectorSlots[selectorCount >> 3];
        }
        // loop through diamond cut
        for (uint256 facetIndex; facetIndex < _diamondCut.length; ) {
            (selectorCount, selectorSlot) = addReplaceRemoveFacetSelectors(
                selectorCount,
                selectorSlot,
                _diamondCut[facetIndex].facetAddress,
                _diamondCut[facetIndex].action,
                _diamondCut[facetIndex].functionSelectors
            );
            unchecked {
                facetIndex++;
            }
        }
        if (selectorCount != originalSelectorCount) {
            // An explicit narrowing cast truncates silently, and selector positions are
            // read back from `facets` as uint16, so revert rather than store a
            // truncated count.
            require(selectorCount <= type(uint16).max, "LibDiamondCut: Too many function selectors");
            ds.selectorCount = uint16(selectorCount);
        }
        // If last selector slot is not full
        // "selectorCount & 7" is a gas efficient modulo by eight "selectorCount % 8"
        if (selectorCount & 7 > 0) {
            // "selectorCount >> 3" is a gas efficient division by 8 "selectorCount / 8"
            ds.selectorSlots[selectorCount >> 3] = selectorSlot;
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @dev Applies a single facet cut (Add, Replace or Remove) to the selector registry.
    ///      The partially filled last slot is carried in `_selectorSlot` and is only
    ///      written to storage here when it becomes full (Add) or empty (Remove);
    ///      the caller is responsible for persisting the returned values.
    /// @param _selectorCount Number of selectors registered before this cut.
    /// @param _selectorSlot In-memory copy of the last selector slot.
    /// @param _newFacetAddress Facet address for Add/Replace; must be address(0) for Remove.
    /// @param _action The cut action to perform.
    /// @param _selectors Selectors to add, replace or remove; must not be empty.
    /// @return The selector count after the cut.
    /// @return The in-memory copy of the last selector slot after the cut.
    function addReplaceRemoveFacetSelectors(
        uint256 _selectorCount,
        bytes32 _selectorSlot,
        address _newFacetAddress,
        IDiamondCut.FacetCutAction _action,
        bytes4[] memory _selectors
    ) internal returns (uint256, bytes32) {
        DiamondStorage storage ds = diamondStorage();
        require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        if (_action == IDiamondCut.FacetCutAction.Add) {
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Add facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; ) {
                bytes4 selector = _selectors[selectorIndex];
                bytes32 oldFacet = ds.facets[selector];
                require(address(bytes20(oldFacet)) == address(0), "LibDiamondCut: Can't add function that already exists");
                // add facet for selector
                // (facet address in the high 20 bytes, selector position in the low bytes)
                ds.facets[selector] = bytes20(_newFacetAddress) | bytes32(_selectorCount);
                // "_selectorCount & 7" is a gas efficient modulo by eight "_selectorCount % 8"
                // " << 5 is the same as multiplying by 32 ( * 32)
                uint256 selectorInSlotPosition = (_selectorCount & 7) << 5;
                // clear selector position in slot and add selector
                _selectorSlot = (_selectorSlot & ~(CLEAR_SELECTOR_MASK >> selectorInSlotPosition)) | (bytes32(selector) >> selectorInSlotPosition);
                // if slot is full then write it to storage
                // (224 == 7 * 32, i.e. the eighth and last position in the slot)
                if (selectorInSlotPosition == 224) {
                    // "_selectorCount >> 3" is a gas efficient division by 8 "_selectorCount / 8"
                    ds.selectorSlots[_selectorCount >> 3] = _selectorSlot;
                    _selectorSlot = 0;
                }
                _selectorCount++;
                unchecked {
                    selectorIndex++;
                }
            }
        } else if (_action == IDiamondCut.FacetCutAction.Replace) {
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Replace facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; ) {
                bytes4 selector = _selectors[selectorIndex];
                bytes32 oldFacet = ds.facets[selector];
                address oldFacetAddress = address(bytes20(oldFacet));
                // only useful if immutable functions exist
                require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
                require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
                require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
                // replace old facet address, keeping the stored selector position
                ds.facets[selector] = (oldFacet & CLEAR_ADDRESS_MASK) | bytes20(_newFacetAddress);
                unchecked {
                    selectorIndex++;
                }
            }
        } else if (_action == IDiamondCut.FacetCutAction.Remove) {
            require(_newFacetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
            // "_selectorCount >> 3" is a gas efficient division by 8 "_selectorCount / 8"
            uint256 selectorSlotCount = _selectorCount >> 3;
            // "_selectorCount & 7" is a gas efficient modulo by eight "_selectorCount % 8"
            uint256 selectorInSlotIndex = _selectorCount & 7;
            for (uint256 selectorIndex; selectorIndex < _selectors.length; ) {
                // Step back by one so that (selectorSlotCount, selectorInSlotIndex)
                // addresses the last registered selector.
                if (selectorInSlotIndex == 0) {
                    // get last selectorSlot
                    selectorSlotCount--;
                    _selectorSlot = ds.selectorSlots[selectorSlotCount];
                    selectorInSlotIndex = 7;
                } else {
                    selectorInSlotIndex--;
                }
                bytes4 lastSelector;
                uint256 oldSelectorsSlotCount;
                uint256 oldSelectorInSlotPosition;
                // adding a block here prevents stack too deep error
                {
                    bytes4 selector = _selectors[selectorIndex];
                    bytes32 oldFacet = ds.facets[selector];
                    require(address(bytes20(oldFacet)) != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
                    // only useful if immutable functions exist
                    require(address(bytes20(oldFacet)) != address(this), "LibDiamondCut: Can't remove immutable function");
                    // replace selector with last selector in ds.facets
                    // gets the last selector
                    // " << 5 is the same as multiplying by 32 ( * 32)
                    lastSelector = bytes4(_selectorSlot << (selectorInSlotIndex << 5));
                    if (lastSelector != selector) {
                        // update last selector slot position info:
                        // keep the last selector's facet address, take the removed selector's position
                        ds.facets[lastSelector] = (oldFacet & CLEAR_ADDRESS_MASK) | bytes20(ds.facets[lastSelector]);
                    }
                    delete ds.facets[selector];
                    // position of the removed selector, read from the low 16 bits of its entry
                    uint256 oldSelectorCount = uint16(uint256(oldFacet));
                    // "oldSelectorCount >> 3" is a gas efficient division by 8 "oldSelectorCount / 8"
                    oldSelectorsSlotCount = oldSelectorCount >> 3;
                    // "oldSelectorCount & 7" is a gas efficient modulo by eight "oldSelectorCount % 8"
                    // " << 5 is the same as multiplying by 32 ( * 32)
                    oldSelectorInSlotPosition = (oldSelectorCount & 7) << 5;
                }
                if (oldSelectorsSlotCount != selectorSlotCount) {
                    // The removed selector lives in a different slot than the last one:
                    // patch that slot directly in storage.
                    bytes32 oldSelectorSlot = ds.selectorSlots[oldSelectorsSlotCount];
                    // clears the selector we are deleting and puts the last selector in its place.
                    oldSelectorSlot =
                        (oldSelectorSlot & ~(CLEAR_SELECTOR_MASK >> oldSelectorInSlotPosition)) |
                        (bytes32(lastSelector) >> oldSelectorInSlotPosition);
                    // update storage with the modified slot
                    ds.selectorSlots[oldSelectorsSlotCount] = oldSelectorSlot;
                } else {
                    // The removed selector lives in the last slot: patch the in-memory copy.
                    // clears the selector we are deleting and puts the last selector in its place.
                    _selectorSlot =
                        (_selectorSlot & ~(CLEAR_SELECTOR_MASK >> oldSelectorInSlotPosition)) |
                        (bytes32(lastSelector) >> oldSelectorInSlotPosition);
                }
                // The last slot no longer holds any selector: drop it from storage.
                if (selectorInSlotIndex == 0) {
                    delete ds.selectorSlots[selectorSlotCount];
                    _selectorSlot = 0;
                }
                unchecked {
                    selectorIndex++;
                }
            }
            _selectorCount = selectorSlotCount * 8 + selectorInSlotIndex;
        } else {
            revert("LibDiamondCut: Incorrect FacetCutAction");
        }
        return (_selectorCount, _selectorSlot);
    }

    /// @dev Delegatecalls `_init` with `_calldata` unless `_init` is address(0).
    ///      On failure, re-raises the callee's revert data if there is any, otherwise
    ///      reverts with InitializationFunctionReverted.
    /// @param _init Initialization contract; address(0) makes this a no-op.
    /// @param _calldata Calldata forwarded to the delegatecall.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            return;
        }
        enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
        (bool success, bytes memory error) = _init.delegatecall(_calldata);
        if (!success) {
            if (error.length > 0) {
                // bubble up error
                /// @solidity memory-safe-assembly
                assembly {
                    let returndata_size := mload(error)
                    revert(add(32, error), returndata_size)
                }
            } else {
                revert InitializationFunctionReverted(_init, _calldata);
            }
        }
    }

    /// @dev Reverts with `_errorMessage` when `_contract` has no code (extcodesize is zero).
    /// @param _contract Address whose code size is checked.
    /// @param _errorMessage Revert reason used when the check fails.
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        require(contractSize > 0, _errorMessage);
    }
}