327 lines
9.7 KiB
C++
327 lines
9.7 KiB
C++
#ifndef CONCURRENTMAP_H
|
||
#define CONCURRENTMAP_H
|
||
|
||
#include <atomic>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>

#include "ConcurrentLinkedList.h"
|
||
|
||
namespace CTL {
|
||
/**
|
||
* ConcurrentMap类提供了一个线程安全的映射实现。
|
||
* 它使用链表和哈希映射的组合来存储键值对,以实现并发访问和修改。
|
||
*
|
||
* @tparam Key 键的类型
|
||
* @tparam Value 值的类型
|
||
*/
|
||
template <typename Key, typename Value>
|
||
class ConcurrentMap {
|
||
public:
|
||
// 类型别名定义
|
||
using key_type = Key;
|
||
using mapped_type = Value;
|
||
using value_type = std::pair<const Key, Value>;
|
||
using size_type = std::size_t;
|
||
|
||
private:
|
||
// 链表节点结构体
|
||
struct Node {
|
||
value_type data;
|
||
std::atomic<Node*> next;
|
||
std::atomic<Node*> prev;
|
||
Node(const value_type& d) : data(d), next(nullptr), prev(nullptr) {}
|
||
};
|
||
|
||
using node_ptr = std::shared_ptr<Node>; // 使用 shared_ptr 管理节点
|
||
|
||
// 头尾指针和元素计数
|
||
std::atomic<Node*> head;
|
||
std::atomic<Node*> tail;
|
||
std::unordered_map<Key, node_ptr> map; // 使用 shared_ptr 代替 atomic
|
||
std::atomic<size_type> count;
|
||
|
||
public:
|
||
// 默认构造函数
|
||
ConcurrentMap() : head(nullptr), tail(nullptr), count(0) {}
|
||
|
||
// 析构函数,清理所有节点
|
||
~ConcurrentMap() {
|
||
clear();
|
||
}
|
||
|
||
// 清空链表和哈希映射中的所有节点
|
||
void clear() {
|
||
Node* current = head.load();
|
||
while (current) {
|
||
Node* next = current->next.load();
|
||
map.erase(current->data.first); // 从 map 中移除节点
|
||
current->next.store(nullptr);
|
||
current->prev.store(nullptr);
|
||
current = next;
|
||
}
|
||
head.store(nullptr);
|
||
tail.store(nullptr);
|
||
map.clear();
|
||
count.store(0);
|
||
}
|
||
|
||
// 获取当前元素数量
|
||
size_type size() const {
|
||
return count.load();
|
||
}
|
||
|
||
// 检查映射是否为空
|
||
bool IsEmpty() const {
|
||
return count.load() == 0;
|
||
}
|
||
|
||
// 插入或获取指定键的值
|
||
Value& operator[](const Key& key) {
|
||
auto newNode = std::make_shared<Node>(value_type(key, Value()));
|
||
node_ptr insertedNode = insertOrGetNode(key, newNode);
|
||
if (insertedNode != newNode) {
|
||
// 不需要显式删除 newNode,shared_ptr 会自动管理
|
||
}
|
||
return insertedNode->data.second;
|
||
}
|
||
|
||
// 获取指定键的值,如果键不存在则抛出异常
|
||
Value& get(const Key& key) {
|
||
node_ptr node = getNode(key);
|
||
if (!node) {
|
||
throw std::out_of_range("Key not found in CCLinkedMap");
|
||
}
|
||
return node->data.second;
|
||
}
|
||
|
||
// 获取指定键的值,如果键不存在则抛出异常(常量版本)
|
||
const Value& get(const Key& key) const {
|
||
node_ptr node = getNode(key);
|
||
if (!node) {
|
||
throw std::out_of_range("Key not found in CCLinkedMap");
|
||
}
|
||
return node->data.second;
|
||
}
|
||
|
||
// 移除指定键的节点
|
||
void remove(const Key& key) {
|
||
node_ptr nodeToRemove = getNode(key);
|
||
if (!nodeToRemove) {
|
||
return;
|
||
}
|
||
|
||
Node* prevNode = nodeToRemove->prev.load();
|
||
Node* nextNode = nodeToRemove->next.load();
|
||
|
||
if (prevNode) {
|
||
prevNode->next.store(nextNode);
|
||
} else {
|
||
head.store(nextNode);
|
||
}
|
||
|
||
if (nextNode) {
|
||
nextNode->prev.store(prevNode);
|
||
} else {
|
||
tail.store(prevNode);
|
||
}
|
||
|
||
map.erase(key);
|
||
count.fetch_sub(1);
|
||
}
|
||
std::list<Value> values() {
|
||
std::list<Value> result;
|
||
for (auto it = map.begin(); it != map.end(); ++it) {
|
||
result.push_back(it->second->data.second);
|
||
}
|
||
return result;
|
||
}
|
||
ArrayList<Value> toArrayList() {
|
||
ArrayList<Value> result;
|
||
for (auto it = map.begin(); it != map.end(); ++it) {
|
||
result.add(it->second->data.second);
|
||
}
|
||
return result;
|
||
}
|
||
ConcurrentLinkedList<Value> toConcurrentLinkedList() {
|
||
ConcurrentLinkedList<Value> result;
|
||
for (auto it = map.begin(); it != map.end(); ++it) {
|
||
result.add(it->second->data.second);
|
||
}
|
||
return result;
|
||
}
|
||
private:
|
||
// 插入新节点或获取已存在节点的指针
|
||
node_ptr insertOrGetNode(const Key& key, node_ptr newNode) {
|
||
while (true) {
|
||
auto it = map.find(key);
|
||
if (it != map.end()) {
|
||
return it->second;
|
||
}
|
||
|
||
auto [itInserted, success] = map.emplace(key, newNode);
|
||
if (success) {
|
||
Node* prevTail = tail.load();
|
||
while (!tail.compare_exchange_weak(prevTail, newNode.get())) {
|
||
if (prevTail) {
|
||
prevTail->next.store(newNode.get());
|
||
newNode->prev.store(prevTail);
|
||
} else {
|
||
head.store(newNode.get());
|
||
}
|
||
}
|
||
count.fetch_add(1);
|
||
return newNode;
|
||
} else {
|
||
return itInserted->second;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 获取指定键的节点指针,如果键不存在则返回nullptr
|
||
node_ptr getNode(const Key& key) const {
|
||
auto it = map.find(key);
|
||
if (it != map.end()) {
|
||
return it->second;
|
||
}
|
||
return nullptr;
|
||
}
|
||
|
||
public:
|
||
// 迭代器类,用于遍历ConcurrentMap
|
||
class iterator {
|
||
public:
|
||
// 迭代器类型定义
|
||
using iterator_category = std::forward_iterator_tag;
|
||
using value_type = std::pair<const Key, Value>;
|
||
using difference_type = std::ptrdiff_t;
|
||
using pointer = value_type*;
|
||
using reference = value_type&;
|
||
|
||
iterator(Node* ptr) : current(ptr) {}
|
||
|
||
reference operator*() const {
|
||
return current->data;
|
||
}
|
||
|
||
pointer operator->() const {
|
||
return &(current->data);
|
||
}
|
||
|
||
iterator& operator++() {
|
||
if (current) {
|
||
current = current->next.load();
|
||
}
|
||
return *this;
|
||
}
|
||
|
||
iterator operator++(int) {
|
||
iterator tmp = *this;
|
||
++(*this);
|
||
return tmp;
|
||
}
|
||
|
||
friend bool operator==(const iterator& a, const iterator& b) {
|
||
return a.current == b.current;
|
||
}
|
||
|
||
friend bool operator!=(const iterator& a, const iterator& b) {
|
||
return a.current != b.current;
|
||
}
|
||
|
||
private:
|
||
Node* current;
|
||
};
|
||
|
||
// 常量迭代器类,用于遍历ConcurrentMap
|
||
class const_iterator {
|
||
public:
|
||
// 迭代器类型定义
|
||
using iterator_category = std::forward_iterator_tag;
|
||
using value_type = std::pair<const Key, Value>;
|
||
using difference_type = std::ptrdiff_t;
|
||
using pointer = const value_type*;
|
||
using reference = const value_type&;
|
||
|
||
const_iterator(Node* ptr) : current(ptr) {}
|
||
|
||
reference operator*() const {
|
||
return current->data;
|
||
}
|
||
|
||
pointer operator->() const {
|
||
return &(current->data);
|
||
}
|
||
|
||
const_iterator& operator++() {
|
||
if (current) {
|
||
current = current->next.load();
|
||
}
|
||
return *this;
|
||
}
|
||
|
||
const_iterator operator++(int) {
|
||
const_iterator tmp = *this;
|
||
++(*this);
|
||
return tmp;
|
||
}
|
||
|
||
friend bool operator==(const const_iterator& a, const const_iterator& b) {
|
||
return a.current == b.current;
|
||
}
|
||
|
||
friend bool operator!=(const const_iterator& a, const const_iterator& b) {
|
||
return a.current != b.current;
|
||
}
|
||
|
||
private:
|
||
Node* current;
|
||
};
|
||
|
||
// 迭代器访问方法
|
||
iterator begin() {
|
||
return iterator(head.load());
|
||
}
|
||
|
||
iterator end() {
|
||
return iterator(nullptr);
|
||
}
|
||
|
||
const_iterator begin() const {
|
||
return const_iterator(head.load());
|
||
}
|
||
|
||
const_iterator end() const {
|
||
return const_iterator(nullptr);
|
||
}
|
||
|
||
const_iterator cbegin() const {
|
||
return const_iterator(head.load());
|
||
}
|
||
|
||
const_iterator cend() const {
|
||
return const_iterator(nullptr);
|
||
}
|
||
|
||
// 友元函数,用于输出ConcurrentMap的内容
|
||
friend std::ostream& operator<<(std::ostream& os, const ConcurrentMap<Key, Value>& lm) {
|
||
os << "[";
|
||
bool first = true;
|
||
for (const auto& pair : lm) {
|
||
if (!first) {
|
||
os << ",";
|
||
}
|
||
os << "{ " << pair.first << " : " << pair.second << " }";
|
||
first = false;
|
||
}
|
||
os << "]";
|
||
return os;
|
||
}
|
||
};
|
||
}
|
||
|
||
#endif
|