Unverified Commit bd077846 by Nicolás Venturo Committed by GitHub

Add EnumerableMap, refactor ERC721 (#2160)

* Implement AddressSet in terms of a generic Set

* Add Uint256Set

* Add EnumerableMap

* Fix wording on EnumerableSet docs and tests

* Refactor ERC721 using EnumerableSet and EnumerableMap

* Fix tests

* Fix linter error

* Gas optimization for EnumerableMap

* Gas optimization for EnumerableSet

* Remove often not-taken if from Enumerable data structures

* Fix failing test

* Gas optimization for EnumerableMap

* Fix linter errors

* Add comment for clarification

* Improve test naming

* Rename EnumerableMap.add to set

* Add overload for EnumerableMap.get with custom error message

* Improve Enumerable docs

* Rename Uint256Set to UintSet

* Add changelog entry
parent 0408e51a
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
### New features ### New features
* `AccessControl`: new contract for managing permissions in a system, replacement for `Ownable` and `Roles`. ([#2112](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2112)) * `AccessControl`: new contract for managing permissions in a system, replacement for `Ownable` and `Roles`. ([#2112](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2112))
* `SafeCast`: new functions to convert to and from signed and unsigned values: `toUint256` and `toInt256`. ([#2123](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2123)) * `SafeCast`: new functions to convert to and from signed and unsigned values: `toUint256` and `toInt256`. ([#2123](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2123))
* `EnumerableMap`: a new data structure for key-value pairs (like `mapping`) that can be iterated over. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
### Breaking changes ### Breaking changes
* `ERC721`: `burn(owner, tokenId)` was removed, use `burn(tokenId)` instead. ([#2125](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2125)) * `ERC721`: `burn(owner, tokenId)` was removed, use `burn(tokenId)` instead. ([#2125](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2125))
...@@ -30,6 +31,8 @@ ...@@ -30,6 +31,8 @@
* `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134)) * `ERC777`: removed `_callsTokensToSend` and `_callTokensReceived`. ([#2134](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2134))
* `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151)) * `EnumerableSet`: renamed `get` to `at`. ([#2151](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2151))
* `ERC165Checker`: functions no longer have a leading underscore. ([#2150](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2150)) * `ERC165Checker`: functions no longer have a leading underscore. ([#2150](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2150))
* `ERC721Metadata`, `ERC721Enumerable`: these contracts were removed, and their functionality merged into `ERC721`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
* `ERC721`: added a constructor for `name` and `symbol`. ([#2160](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2160))
* `ERC20Detailed`: this contract was removed and its functionality merged into `ERC20`. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161)) * `ERC20Detailed`: this contract was removed and its functionality merged into `ERC20`. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161))
* `ERC20`: added a constructor for `name` and `symbol`. `decimals` now defaults to 18. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161)) * `ERC20`: added a constructor for `name` and `symbol`. `decimals` now defaults to 18. ([#2161](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/2161))
......
...@@ -13,10 +13,6 @@ contract ERC721Mock is ERC721 { ...@@ -13,10 +13,6 @@ contract ERC721Mock is ERC721 {
return _exists(tokenId); return _exists(tokenId);
} }
function tokensOfOwner(address owner) public view returns (uint256[] memory) {
return _tokensOfOwner(owner);
}
function setTokenURI(uint256 tokenId, string memory uri) public { function setTokenURI(uint256 tokenId, string memory uri) public {
_setTokenURI(tokenId, uri); _setTokenURI(tokenId, uri);
} }
......
pragma solidity ^0.6.0;
import "../utils/EnumerableMap.sol";
contract EnumerableMapMock {
using EnumerableMap for EnumerableMap.UintToAddressMap;
event OperationResult(bool result);
EnumerableMap.UintToAddressMap private _map;
function contains(uint256 key) public view returns (bool) {
return _map.contains(key);
}
function set(uint256 key, address value) public {
bool result = _map.set(key, value);
emit OperationResult(result);
}
function remove(uint256 key) public {
bool result = _map.remove(key);
emit OperationResult(result);
}
function length() public view returns (uint256) {
return _map.length();
}
function at(uint256 index) public view returns (uint256 key, address value) {
return _map.at(index);
}
function get(uint256 key) public view returns (address) {
return _map.get(key);
}
}
...@@ -5,7 +5,7 @@ import "../utils/EnumerableSet.sol"; ...@@ -5,7 +5,7 @@ import "../utils/EnumerableSet.sol";
contract EnumerableSetMock { contract EnumerableSetMock {
using EnumerableSet for EnumerableSet.AddressSet; using EnumerableSet for EnumerableSet.AddressSet;
event TransactionResult(bool result); event OperationResult(bool result);
EnumerableSet.AddressSet private _set; EnumerableSet.AddressSet private _set;
...@@ -15,16 +15,12 @@ contract EnumerableSetMock { ...@@ -15,16 +15,12 @@ contract EnumerableSetMock {
function add(address value) public { function add(address value) public {
bool result = _set.add(value); bool result = _set.add(value);
emit TransactionResult(result); emit OperationResult(result);
} }
function remove(address value) public { function remove(address value) public {
bool result = _set.remove(value); bool result = _set.remove(value);
emit TransactionResult(result); emit OperationResult(result);
}
function enumerate() public view returns (address[] memory) {
return _set.enumerate();
} }
function length() public view returns (uint256) { function length() public view returns (uint256) {
......
pragma solidity ^0.6.0;
library EnumerableMap {
// To implement this library for multiple types with as little code
// repetition as possible, we write it in terms of a generic Map type with
// bytes32 keys and values.
// The Map implementation uses private functions, and user-facing
// implementations (such as Uint256ToAddressMap) are just wrappers around
// the underlying Map.
// This means that we can only create new EnumerableMaps for types that fit
// in bytes32.
struct MapEntry {
bytes32 _key;
bytes32 _value;
}
struct Map {
// Storage of map keys and values
MapEntry[] _entries;
// Position of the entry defined by a key in the `entries` array, plus 1
// because index 0 means a key is not in the map.
mapping (bytes32 => uint256) _indexes;
}
/**
* @dev Adds a key-value pair to a map, or updates the value for an existing
* key. O(1).
*
* Returns true if the key was added to the map, that is if it was not
* already present.
*/
function _set(Map storage map, bytes32 key, bytes32 value) private returns (bool) {
// We read and store the key's index to prevent multiple reads from the same storage slot
uint256 keyIndex = map._indexes[key];
if (keyIndex == 0) { // Equivalent to !contains(map, key)
map._entries.push(MapEntry({ _key: key, _value: value }));
// The entry is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value
map._indexes[key] = map._entries.length;
return true;
} else {
map._entries[keyIndex - 1]._value = value;
return false;
}
}
/**
* @dev Removes a key-value pair from a map. O(1).
*
* Returns true if the key was removed from the map, that is if it was present.
*/
function _remove(Map storage map, bytes32 key) private returns (bool) {
// We read and store the key's index to prevent multiple reads from the same storage slot
uint256 keyIndex = map._indexes[key];
if (keyIndex != 0) { // Equivalent to contains(map, key)
// To delete a key-value pair from the _entries array in O(1), we swap the entry to delete with the last one
// in the array, and then remove the last entry (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.
uint256 toDeleteIndex = keyIndex - 1;
uint256 lastIndex = map._entries.length - 1;
// When the entry to delete is the last one, the swap operation is unnecessary. However, since this occurs
// so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement.
MapEntry storage lastEntry = map._entries[lastIndex];
// Move the last entry to the index where the entry to delete is
map._entries[toDeleteIndex] = lastEntry;
// Update the index for the moved entry
map._indexes[lastEntry._key] = toDeleteIndex + 1; // All indexes are 1-based
// Delete the slot where the moved entry was stored
map._entries.pop();
// Delete the index for the deleted slot
delete map._indexes[key];
return true;
} else {
return false;
}
}
/**
* @dev Returns true if the key is in the map. O(1).
*/
function _contains(Map storage map, bytes32 key) private view returns (bool) {
return map._indexes[key] != 0;
}
/**
* @dev Returns the number of key-value pairs in the map. O(1).
*/
function _length(Map storage map) private view returns (uint256) {
return map._entries.length;
}
/**
* @dev Returns the key-value pair stored at position `index` in the map. O(1).
*
* Note that there are no guarantees on the ordering of entries inside the
* array, and it may change when more entries are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function _at(Map storage map, uint256 index) private view returns (bytes32, bytes32) {
require(map._entries.length > index, "EnumerableMap: index out of bounds");
MapEntry storage entry = map._entries[index];
return (entry._key, entry._value);
}
/**
* @dev Returns the value associated with `key`. O(1).
*
* Requirements:
*
* - `key` must be in the map.
*/
function _get(Map storage map, bytes32 key) private view returns (bytes32) {
return _get(map, key, "EnumerableMap: nonexistent key");
}
/**
* @dev Same as {_get}, with a custom error message when `key` is not in the map.
*/
function _get(Map storage map, bytes32 key, string memory errorMessage) private view returns (bytes32) {
uint256 keyIndex = map._indexes[key];
require(keyIndex != 0, errorMessage); // Equivalent to contains(map, key)
return map._entries[keyIndex - 1]._value; // All indexes are 1-based
}
// UintToAddressMap
struct UintToAddressMap {
Map _inner;
}
/**
* @dev Adds a key-value pair to a map, or updates the value for an existing
* key. O(1).
*
* Returns true if the key was added to the map, that is if it was not
* already present.
*/
function set(UintToAddressMap storage map, uint256 key, address value) internal returns (bool) {
return _set(map._inner, bytes32(key), bytes32(uint256(value)));
}
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the key was removed from the map, that is if it was present.
*/
function remove(UintToAddressMap storage map, uint256 key) internal returns (bool) {
return _remove(map._inner, bytes32(key));
}
/**
* @dev Returns true if the key is in the map. O(1).
*/
function contains(UintToAddressMap storage map, uint256 key) internal view returns (bool) {
return _contains(map._inner, bytes32(key));
}
/**
* @dev Returns the number of elements in the map. O(1).
*/
function length(UintToAddressMap storage map) internal view returns (uint256) {
return _length(map._inner);
}
/**
* @dev Returns the element stored at position `index` in the set. O(1).
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function at(UintToAddressMap storage map, uint256 index) internal view returns (uint256, address) {
(bytes32 key, bytes32 value) = _at(map._inner, index);
return (uint256(key), address(uint256(value)));
}
/**
* @dev Returns the value associated with `key`. O(1).
*
* Requirements:
*
* - `key` must be in the map.
*/
function get(UintToAddressMap storage map, uint256 key) internal view returns (address) {
return address(uint256(_get(map._inner, bytes32(key))));
}
/**
* @dev Same as {get}, with a custom error message when `key` is not in the map.
*/
function get(UintToAddressMap storage map, uint256 key, string memory errorMessage) internal view returns (address) {
return address(uint256(_get(map._inner, bytes32(key), errorMessage)));
}
}
...@@ -18,24 +18,32 @@ pragma solidity ^0.6.0; ...@@ -18,24 +18,32 @@ pragma solidity ^0.6.0;
* @author Alberto Cuesta Cañada * @author Alberto Cuesta Cañada
*/ */
library EnumerableSet { library EnumerableSet {
// To implement this library for multiple types with as little code
// repetition as possible, we write it in terms of a generic Set type with
// bytes32 values.
// The Set implementation uses private functions, and user-facing
// implementations (such as AddressSet) are just wrappers around the
// underlying Set.
// This means that we can only create new EnumerableSets for types that fit
// in bytes32.
struct Set {
// Storage of set values
bytes32[] _values;
struct AddressSet {
address[] _values;
// Position of the value in the `values` array, plus 1 because index 0 // Position of the value in the `values` array, plus 1 because index 0
// means a value is not in the set. // means a value is not in the set.
mapping (address => uint256) _indexes; mapping (bytes32 => uint256) _indexes;
} }
/** /**
* @dev Add a value to a set. O(1). * @dev Add a value to a set. O(1).
* *
* Returns false if the value was already in the set. * Returns true if the value was added to the set, that is if it was not
* already present.
*/ */
function add(AddressSet storage set, address value) function _add(Set storage set, bytes32 value) private returns (bool) {
internal if (!_contains(set, value)) {
returns (bool)
{
if (!contains(set, value)) {
set._values.push(value); set._values.push(value);
// The value is stored at length-1, but we add 1 to all indexes // The value is stored at length-1, but we add 1 to all indexes
// and use 0 as a sentinel value // and use 0 as a sentinel value
...@@ -49,25 +57,30 @@ library EnumerableSet { ...@@ -49,25 +57,30 @@ library EnumerableSet {
/** /**
* @dev Removes a value from a set. O(1). * @dev Removes a value from a set. O(1).
* *
* Returns false if the value was not present in the set. * Returns true if the value was removed from the set, that is if it was
* present.
*/ */
function remove(AddressSet storage set, address value) function _remove(Set storage set, bytes32 value) private returns (bool) {
internal // We read and store the value's index to prevent multiple reads from the same storage slot
returns (bool) uint256 valueIndex = set._indexes[value];
{
if (contains(set, value)){ if (valueIndex != 0) { // Equivalent to contains(set, value)
uint256 toDeleteIndex = set._indexes[value] - 1; // To delete an element from the _values array in O(1), we swap the element to delete with the last one in
// the array, and then remove the last element (sometimes called as 'swap and pop').
// This modifies the order of the array, as noted in {at}.
uint256 toDeleteIndex = valueIndex - 1;
uint256 lastIndex = set._values.length - 1; uint256 lastIndex = set._values.length - 1;
// If the value we're deleting is the last one, we can just remove it without doing a swap // When the value to delete is the last one, the swap operation is unnecessary. However, since this occurs
if (lastIndex != toDeleteIndex) { // so rarely, we still do the swap anyway to avoid the gas cost of adding an 'if' statement.
address lastvalue = set._values[lastIndex];
bytes32 lastvalue = set._values[lastIndex];
// Move the last value to the index where the deleted value is // Move the last value to the index where the value to delete is
set._values[toDeleteIndex] = lastvalue; set._values[toDeleteIndex] = lastvalue;
// Update the index for the moved value // Update the index for the moved value
set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based set._indexes[lastvalue] = toDeleteIndex + 1; // All indexes are 1-based
}
// Delete the slot where the moved value was stored // Delete the slot where the moved value was stored
set._values.pop(); set._values.pop();
...@@ -84,44 +97,125 @@ library EnumerableSet { ...@@ -84,44 +97,125 @@ library EnumerableSet {
/** /**
* @dev Returns true if the value is in the set. O(1). * @dev Returns true if the value is in the set. O(1).
*/ */
function contains(AddressSet storage set, address value) function _contains(Set storage set, bytes32 value) private view returns (bool) {
internal
view
returns (bool)
{
return set._indexes[value] != 0; return set._indexes[value] != 0;
} }
/** /**
* @dev Returns an array with all values in the set. O(N). * @dev Returns the number of values on the set. O(1).
*/
function _length(Set storage set) private view returns (uint256) {
return set._values.length;
}
/**
* @dev Returns the value stored at position `index` in the set. O(1).
*
* Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function _at(Set storage set, uint256 index) private view returns (bytes32) {
require(set._values.length > index, "EnumerableSet: index out of bounds");
return set._values[index];
}
// AddressSet
struct AddressSet {
Set _inner;
}
/**
* @dev Add a value to a set. O(1).
*
* Returns true if the value was added to the set, that is if it was not
* already present.
*/
function add(AddressSet storage set, address value) internal returns (bool) {
return _add(set._inner, bytes32(uint256(value)));
}
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the value was removed from the set, that is if it was
* present.
*/
function remove(AddressSet storage set, address value) internal returns (bool) {
return _remove(set._inner, bytes32(uint256(value)));
}
/**
* @dev Returns true if the value is in the set. O(1).
*/
function contains(AddressSet storage set, address value) internal view returns (bool) {
return _contains(set._inner, bytes32(uint256(value)));
}
/**
* @dev Returns the number of values in the set. O(1).
*/
function length(AddressSet storage set) internal view returns (uint256) {
return _length(set._inner);
}
/**
* @dev Returns the value stored at position `index` in the set. O(1).
* *
* Note that there are no guarantees on the ordering of values inside the * Note that there are no guarantees on the ordering of values inside the
* array, and it may change when more values are added or removed. * array, and it may change when more values are added or removed.
*
* Requirements:
*
* - `index` must be strictly less than {length}.
*/
function at(AddressSet storage set, uint256 index) internal view returns (address) {
return address(uint256(_at(set._inner, index)));
}
* WARNING: This function may run out of gas on large sets: use {length} and // UintSet
* {at} instead in these cases.
struct UintSet {
Set _inner;
}
/**
* @dev Add a value to a set. O(1).
*
* Returns true if the value was added to the set, that is if it was not
* already present.
*/ */
function enumerate(AddressSet storage set) function add(UintSet storage set, uint256 value) internal returns (bool) {
internal return _add(set._inner, bytes32(value));
view
returns (address[] memory)
{
address[] memory output = new address[](set._values.length);
for (uint256 i; i < set._values.length; i++){
output[i] = set._values[i];
} }
return output;
/**
* @dev Removes a value from a set. O(1).
*
* Returns true if the value was removed from the set, that is if it was
* present.
*/
function remove(UintSet storage set, uint256 value) internal returns (bool) {
return _remove(set._inner, bytes32(value));
}
/**
* @dev Returns true if the value is in the set. O(1).
*/
function contains(UintSet storage set, uint256 value) internal view returns (bool) {
return _contains(set._inner, bytes32(value));
} }
/** /**
* @dev Returns the number of values on the set. O(1). * @dev Returns the number of values on the set. O(1).
*/ */
function length(AddressSet storage set) function length(UintSet storage set) internal view returns (uint256) {
internal return _length(set._inner);
view
returns (uint256)
{
return set._values.length;
} }
/** /**
...@@ -134,12 +228,7 @@ library EnumerableSet { ...@@ -134,12 +228,7 @@ library EnumerableSet {
* *
* - `index` must be strictly less than {length}. * - `index` must be strictly less than {length}.
*/ */
function at(AddressSet storage set, uint256 index) function at(UintSet storage set, uint256 index) internal view returns (uint256) {
internal return uint256(_at(set._inner, index));
view
returns (address)
{
require(set._values.length > index, "EnumerableSet: index out of bounds");
return set._values[index];
} }
} }
...@@ -31087,6 +31087,12 @@ ...@@ -31087,6 +31087,12 @@
"lodash._reinterpolate": "^3.0.0" "lodash._reinterpolate": "^3.0.0"
} }
}, },
"lodash.zip": {
"version": "4.2.0",
"resolved": "https://registry.npmjs.org/lodash.zip/-/lodash.zip-4.2.0.tgz",
"integrity": "sha1-7GZi5IlkCO1KtsVCo5kLcswIACA=",
"dev": true
},
"log-symbols": { "log-symbols": {
"version": "3.0.0", "version": "3.0.0",
"resolved": "https://registry.npmjs.org/log-symbols/-/log-symbols-3.0.0.tgz", "resolved": "https://registry.npmjs.org/log-symbols/-/log-symbols-3.0.0.tgz",
...@@ -176,27 +176,17 @@ describe('ERC721', function () { ...@@ -176,27 +176,17 @@ describe('ERC721', function () {
expect(await this.token.ownerOf(tokenId)).to.be.equal(this.toWhom); expect(await this.token.ownerOf(tokenId)).to.be.equal(this.toWhom);
}); });
it('emits a Transfer event', async function () {
expectEvent.inLogs(logs, 'Transfer', { from: owner, to: this.toWhom, tokenId: tokenId });
});
it('clears the approval for the token ID', async function () { it('clears the approval for the token ID', async function () {
expect(await this.token.getApproved(tokenId)).to.be.equal(ZERO_ADDRESS); expect(await this.token.getApproved(tokenId)).to.be.equal(ZERO_ADDRESS);
}); });
if (approved) { it('emits an Approval event', async function () {
it('emit only a transfer event', async function () { expectEvent.inLogs(logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: tokenId });
expectEvent.inLogs(logs, 'Transfer', {
from: owner,
to: this.toWhom,
tokenId: tokenId,
}); });
});
} else {
it('emits only a transfer event', async function () {
expectEvent.inLogs(logs, 'Transfer', {
from: owner,
to: this.toWhom,
tokenId: tokenId,
});
});
}
it('adjusts owners balances', async function () { it('adjusts owners balances', async function () {
expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1'); expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1');
...@@ -708,15 +698,6 @@ describe('ERC721', function () { ...@@ -708,15 +698,6 @@ describe('ERC721', function () {
}); });
}); });
describe('tokensOfOwner', function () {
it('returns total tokens of owner', async function () {
const tokenIds = await this.token.tokensOfOwner(owner);
expect(tokenIds.length).to.equal(2);
expect(tokenIds[0]).to.be.bignumber.equal(firstTokenId);
expect(tokenIds[1]).to.be.bignumber.equal(secondTokenId);
});
});
describe('totalSupply', function () { describe('totalSupply', function () {
it('returns total token supply', async function () { it('returns total token supply', async function () {
expect(await this.token.totalSupply()).to.be.bignumber.equal('2'); expect(await this.token.totalSupply()).to.be.bignumber.equal('2');
...@@ -733,7 +714,7 @@ describe('ERC721', function () { ...@@ -733,7 +714,7 @@ describe('ERC721', function () {
describe('when the index is greater than or equal to the total tokens owned by the given address', function () { describe('when the index is greater than or equal to the total tokens owned by the given address', function () {
it('reverts', async function () { it('reverts', async function () {
await expectRevert( await expectRevert(
this.token.tokenOfOwnerByIndex(owner, 2), 'ERC721Enumerable: owner index out of bounds' this.token.tokenOfOwnerByIndex(owner, 2), 'EnumerableSet: index out of bounds'
); );
}); });
}); });
...@@ -741,7 +722,7 @@ describe('ERC721', function () { ...@@ -741,7 +722,7 @@ describe('ERC721', function () {
describe('when the given address does not own any token', function () { describe('when the given address does not own any token', function () {
it('reverts', async function () { it('reverts', async function () {
await expectRevert( await expectRevert(
this.token.tokenOfOwnerByIndex(other, 0), 'ERC721Enumerable: owner index out of bounds' this.token.tokenOfOwnerByIndex(other, 0), 'EnumerableSet: index out of bounds'
); );
}); });
}); });
...@@ -764,7 +745,7 @@ describe('ERC721', function () { ...@@ -764,7 +745,7 @@ describe('ERC721', function () {
it('returns empty collection for original owner', async function () { it('returns empty collection for original owner', async function () {
expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('0'); expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('0');
await expectRevert( await expectRevert(
this.token.tokenOfOwnerByIndex(owner, 0), 'ERC721Enumerable: owner index out of bounds' this.token.tokenOfOwnerByIndex(owner, 0), 'EnumerableSet: index out of bounds'
); );
}); });
}); });
...@@ -781,7 +762,7 @@ describe('ERC721', function () { ...@@ -781,7 +762,7 @@ describe('ERC721', function () {
it('should revert if index is greater than supply', async function () { it('should revert if index is greater than supply', async function () {
await expectRevert( await expectRevert(
this.token.tokenByIndex(2), 'ERC721Enumerable: global index out of bounds' this.token.tokenByIndex(2), 'EnumerableMap: index out of bounds'
); );
}); });
...@@ -790,7 +771,7 @@ describe('ERC721', function () { ...@@ -790,7 +771,7 @@ describe('ERC721', function () {
const newTokenId = new BN(300); const newTokenId = new BN(300);
const anotherNewTokenId = new BN(400); const anotherNewTokenId = new BN(400);
await this.token.burn(tokenId, { from: owner }); await this.token.burn(tokenId);
await this.token.mint(newOwner, newTokenId); await this.token.mint(newOwner, newTokenId);
await this.token.mint(newOwner, anotherNewTokenId); await this.token.mint(newOwner, anotherNewTokenId);
...@@ -865,6 +846,10 @@ describe('ERC721', function () { ...@@ -865,6 +846,10 @@ describe('ERC721', function () {
expectEvent.inLogs(this.logs, 'Transfer', { from: owner, to: ZERO_ADDRESS, tokenId: firstTokenId }); expectEvent.inLogs(this.logs, 'Transfer', { from: owner, to: ZERO_ADDRESS, tokenId: firstTokenId });
}); });
it('emits an Approval event', function () {
expectEvent.inLogs(this.logs, 'Approval', { owner, approved: ZERO_ADDRESS, tokenId: firstTokenId });
});
it('deletes the token', async function () { it('deletes the token', async function () {
expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1'); expect(await this.token.balanceOf(owner)).to.be.bignumber.equal('1');
await expectRevert( await expectRevert(
...@@ -884,7 +869,7 @@ describe('ERC721', function () { ...@@ -884,7 +869,7 @@ describe('ERC721', function () {
await this.token.burn(secondTokenId, { from: owner }); await this.token.burn(secondTokenId, { from: owner });
expect(await this.token.totalSupply()).to.be.bignumber.equal('0'); expect(await this.token.totalSupply()).to.be.bignumber.equal('0');
await expectRevert( await expectRevert(
this.token.tokenByIndex(0), 'ERC721Enumerable: global index out of bounds' this.token.tokenByIndex(0), 'EnumerableMap: index out of bounds'
); );
}); });
......
const { accounts, contract } = require('@openzeppelin/test-environment');
const { BN, expectEvent } = require('@openzeppelin/test-helpers');
const { expect } = require('chai');
const zip = require('lodash.zip');
const EnumerableMapMock = contract.fromArtifact('EnumerableMapMock');
describe('EnumerableMap', function () {
const [ accountA, accountB, accountC ] = accounts;
const keyA = new BN('7891');
const keyB = new BN('451');
const keyC = new BN('9592328');
beforeEach(async function () {
this.map = await EnumerableMapMock.new();
});
async function expectMembersMatch (map, keys, values) {
expect(keys.length).to.equal(values.length);
await Promise.all(keys.map(async key =>
expect(await map.contains(key)).to.equal(true)
));
expect(await map.length()).to.bignumber.equal(keys.length.toString());
expect(await Promise.all(keys.map(key =>
map.get(key)
))).to.have.same.members(values);
// To compare key-value pairs, we zip keys and values, and convert BNs to
// strings to workaround Chai limitations when dealing with nested arrays
expect(await Promise.all([...Array(keys.length).keys()].map(async (index) => {
const entry = await map.at(index);
return [entry.key.toString(), entry.value];
}))).to.have.same.deep.members(
zip(keys.map(k => k.toString()), values)
);
}
it('starts empty', async function () {
expect(await this.map.contains(keyA)).to.equal(false);
await expectMembersMatch(this.map, [], []);
});
it('adds a key', async function () {
const receipt = await this.map.set(keyA, accountA);
expectEvent(receipt, 'OperationResult', { result: true });
await expectMembersMatch(this.map, [keyA], [accountA]);
});
it('adds several keys', async function () {
await this.map.set(keyA, accountA);
await this.map.set(keyB, accountB);
await expectMembersMatch(this.map, [keyA, keyB], [accountA, accountB]);
expect(await this.map.contains(keyC)).to.equal(false);
});
it('returns false when adding keys already in the set', async function () {
await this.map.set(keyA, accountA);
const receipt = (await this.map.set(keyA, accountA));
expectEvent(receipt, 'OperationResult', { result: false });
await expectMembersMatch(this.map, [keyA], [accountA]);
});
it('updates values for keys already in the set', async function () {
await this.map.set(keyA, accountA);
await this.map.set(keyA, accountB);
await expectMembersMatch(this.map, [keyA], [accountB]);
});
it('removes added keys', async function () {
await this.map.set(keyA, accountA);
const receipt = await this.map.remove(keyA);
expectEvent(receipt, 'OperationResult', { result: true });
expect(await this.map.contains(keyA)).to.equal(false);
await expectMembersMatch(this.map, [], []);
});
it('returns false when removing keys not in the set', async function () {
const receipt = await this.map.remove(keyA);
expectEvent(receipt, 'OperationResult', { result: false });
expect(await this.map.contains(keyA)).to.equal(false);
});
it('adds and removes multiple keys', async function () {
// []
await this.map.set(keyA, accountA);
await this.map.set(keyC, accountC);
// [A, C]
await this.map.remove(keyA);
await this.map.remove(keyB);
// [C]
await this.map.set(keyB, accountB);
// [C, B]
await this.map.set(keyA, accountA);
await this.map.remove(keyC);
// [A, B]
await this.map.set(keyA, accountA);
await this.map.set(keyB, accountB);
// [A, B]
await this.map.set(keyC, accountC);
await this.map.remove(keyA);
// [B, C]
await this.map.set(keyA, accountA);
await this.map.remove(keyB);
// [A, C]
await expectMembersMatch(this.map, [keyA, keyC], [accountA, accountC]);
expect(await this.map.contains(keyB)).to.equal(false);
});
});
...@@ -11,18 +11,16 @@ describe('EnumerableSet', function () { ...@@ -11,18 +11,16 @@ describe('EnumerableSet', function () {
this.set = await EnumerableSetMock.new(); this.set = await EnumerableSetMock.new();
}); });
async function expectMembersMatch (set, members) { async function expectMembersMatch (set, values) {
await Promise.all(members.map(async account => await Promise.all(values.map(async account =>
expect(await set.contains(account)).to.equal(true) expect(await set.contains(account)).to.equal(true)
)); ));
expect(await set.enumerate()).to.have.same.members(members); expect(await set.length()).to.bignumber.equal(values.length.toString());
expect(await set.length()).to.bignumber.equal(members.length.toString()); expect(await Promise.all([...Array(values.length).keys()].map(index =>
expect(await Promise.all([...Array(members.length).keys()].map(index =>
set.at(index) set.at(index)
))).to.have.same.members(members); ))).to.have.same.members(values);
} }
it('starts empty', async function () { it('starts empty', async function () {
...@@ -33,7 +31,7 @@ describe('EnumerableSet', function () { ...@@ -33,7 +31,7 @@ describe('EnumerableSet', function () {
it('adds a value', async function () { it('adds a value', async function () {
const receipt = await this.set.add(accountA); const receipt = await this.set.add(accountA);
expectEvent(receipt, 'TransactionResult', { result: true }); expectEvent(receipt, 'OperationResult', { result: true });
await expectMembersMatch(this.set, [accountA]); await expectMembersMatch(this.set, [accountA]);
}); });
...@@ -46,11 +44,11 @@ describe('EnumerableSet', function () { ...@@ -46,11 +44,11 @@ describe('EnumerableSet', function () {
expect(await this.set.contains(accountC)).to.equal(false); expect(await this.set.contains(accountC)).to.equal(false);
}); });
it('returns false when adding elements already in the set', async function () { it('returns false when adding values already in the set', async function () {
await this.set.add(accountA); await this.set.add(accountA);
const receipt = (await this.set.add(accountA)); const receipt = (await this.set.add(accountA));
expectEvent(receipt, 'TransactionResult', { result: false }); expectEvent(receipt, 'OperationResult', { result: false });
await expectMembersMatch(this.set, [accountA]); await expectMembersMatch(this.set, [accountA]);
}); });
...@@ -63,15 +61,15 @@ describe('EnumerableSet', function () { ...@@ -63,15 +61,15 @@ describe('EnumerableSet', function () {
await this.set.add(accountA); await this.set.add(accountA);
const receipt = await this.set.remove(accountA); const receipt = await this.set.remove(accountA);
expectEvent(receipt, 'TransactionResult', { result: true }); expectEvent(receipt, 'OperationResult', { result: true });
expect(await this.set.contains(accountA)).to.equal(false); expect(await this.set.contains(accountA)).to.equal(false);
await expectMembersMatch(this.set, []); await expectMembersMatch(this.set, []);
}); });
it('returns false when removing elements not in the set', async function () { it('returns false when removing values not in the set', async function () {
const receipt = await this.set.remove(accountA); const receipt = await this.set.remove(accountA);
expectEvent(receipt, 'TransactionResult', { result: false }); expectEvent(receipt, 'OperationResult', { result: false });
expect(await this.set.contains(accountA)).to.equal(false); expect(await this.set.contains(accountA)).to.equal(false);
}); });
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment