Distribution_Service/CC_SDK/Include/basic/ConcurrentMap.h
2025-11-11 17:46:19 +08:00

327 lines
9.7 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#ifndef CONCURRENTMAP_H
#define CONCURRENTMAP_H
#include <atomic>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>
#include "ConcurrentLinkedList.h"
namespace CTL {
/**
 * ConcurrentMap: a key/value map that also records insertion order.
 *
 * Storage: every entry lives in a heap-allocated Node owned by `map`
 * (std::unordered_map<Key, std::shared_ptr<Node>>). Each Node is additionally linked into a
 * doubly linked list in insertion order; begin()/end() walk that list, while values(),
 * toArrayList() and toConcurrentLinkedList() walk the hash map (unspecified order).
 *
 * Thread-safety:
 *  - clear(), operator[], get(), remove(), values(), toArrayList() and
 *    toConcurrentLinkedList() serialize on an internal mutex, so concurrent calls to these
 *    members do not race on the container's structure.
 *  - size() and IsEmpty() read an atomic counter without taking the lock.
 *  - References returned by operator[]/get() and all iterators are NOT protected by the lock:
 *    they stay valid only until that entry is removed or the map is cleared/destroyed, and
 *    synchronizing reads/writes of the referenced Value is the caller's responsibility.
 *  - Iteration (begin()/end(), operator<<) must not overlap with remove() or clear().
 *
 * @tparam Key   key type; must be usable as a std::unordered_map key (std::hash<Key>, ==)
 * @tparam Value mapped type; operator[] default-constructs it for missing keys
 */
template <typename Key, typename Value>
class ConcurrentMap {
public:
    // Standard container-style type aliases.
    using key_type = Key;
    using mapped_type = Value;
    using value_type = std::pair<const Key, Value>;
    using size_type = std::size_t;

private:
    // One entry. Ownership is held by `map` through a shared_ptr;
    // `next`/`prev` are non-owning links of the insertion-order list.
    struct Node {
        value_type data;
        std::atomic<Node*> next;
        std::atomic<Node*> prev;
        explicit Node(const value_type& d) : data(d), next(nullptr), prev(nullptr) {}
    };
    using node_ptr = std::shared_ptr<Node>;

    std::atomic<Node*> head;                // oldest entry still present, or nullptr
    std::atomic<Node*> tail;                // newest entry still present, or nullptr
    std::unordered_map<Key, node_ptr> map;  // key -> owning pointer to its node
    std::atomic<size_type> count;           // entry count; readable without the lock
    mutable std::mutex mtx;                 // guards `map`, `head`/`tail` and all list links

public:
    /// Creates an empty map.
    ConcurrentMap() : head(nullptr), tail(nullptr), count(0) {}

    // Not copyable: the atomic and mutex members already rule it out; stated explicitly.
    ConcurrentMap(const ConcurrentMap&) = delete;
    ConcurrentMap& operator=(const ConcurrentMap&) = delete;

    /// Destroys the map and every node it owns.
    ~ConcurrentMap() {
        clear();
    }

    /// Removes every entry.
    void clear() {
        std::lock_guard<std::mutex> lock(mtx);
        // `map` is the only owner of the nodes, so clearing it frees them all. The list links
        // are abandoned wholesale; touching a node after its owner is erased would be a
        // use-after-free.
        head.store(nullptr);
        tail.store(nullptr);
        map.clear();
        count.store(0);
    }

    /// @return the current number of entries (lock-free read).
    size_type size() const {
        return count.load();
    }

    /// @return true when the map holds no entries (lock-free read).
    bool IsEmpty() const {
        return count.load() == 0;
    }

    /**
     * Returns the value stored under `key`. If the key is absent, a default-constructed
     * Value is inserted and appended to the end of the insertion order first.
     * The returned reference is not protected by the internal lock (see class notes).
     */
    Value& operator[](const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        return insertOrGetNode(key)->data.second;
    }

    /**
     * Returns the value stored under `key`.
     * @throws std::out_of_range if the key is not present.
     */
    Value& get(const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        Node* node = getNode(key);
        if (!node) {
            throw std::out_of_range("Key not found in CCLinkedMap");
        }
        return node->data.second;
    }

    /**
     * Returns the value stored under `key` (const overload).
     * @throws std::out_of_range if the key is not present.
     */
    const Value& get(const Key& key) const {
        std::lock_guard<std::mutex> lock(mtx);
        const Node* node = getNode(key);
        if (!node) {
            throw std::out_of_range("Key not found in CCLinkedMap");
        }
        return node->data.second;
    }

    /// Removes the entry for `key`, if any; absent keys are ignored.
    void remove(const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        auto found = map.find(key);
        if (found == map.end()) {
            return;
        }
        Node* node = found->second.get();
        Node* prevNode = node->prev.load();
        Node* nextNode = node->next.load();
        // Splice the node out of the insertion-order list.
        if (prevNode) {
            prevNode->next.store(nextNode);
        } else {
            head.store(nextNode);
        }
        if (nextNode) {
            nextNode->prev.store(prevNode);
        } else {
            tail.store(prevNode);
        }
        // Erase by iterator: `key` may refer to the node's own key, which dies here.
        map.erase(found);
        count.fetch_sub(1);
    }

    /// @return a copy of all values, in hash-map (unspecified) order.
    std::list<Value> values() {
        std::lock_guard<std::mutex> lock(mtx);
        std::list<Value> result;
        for (const auto& entry : map) {
            result.push_back(entry.second->data.second);
        }
        return result;
    }

    /// @return a copy of all values as an ArrayList, in hash-map (unspecified) order.
    ArrayList<Value> toArrayList() {
        std::lock_guard<std::mutex> lock(mtx);
        ArrayList<Value> result;
        for (const auto& entry : map) {
            result.add(entry.second->data.second);
        }
        return result;
    }

    /// @return a copy of all values as a ConcurrentLinkedList, in hash-map (unspecified) order.
    ConcurrentLinkedList<Value> toConcurrentLinkedList() {
        std::lock_guard<std::mutex> lock(mtx);
        ConcurrentLinkedList<Value> result;
        for (const auto& entry : map) {
            result.add(entry.second->data.second);
        }
        return result;
    }

private:
    /**
     * Returns the node for `key`, creating it (with a default-constructed Value) and
     * appending it to the tail of the insertion-order list when absent.
     * Caller must hold `mtx`.
     */
    Node* insertOrGetNode(const Key& key) {
        auto found = map.find(key);
        if (found != map.end()) {
            return found->second.get();
        }
        // Build the node before touching any state, so a throwing Value() or allocation
        // leaves the container unchanged.
        node_ptr owner = std::make_shared<Node>(value_type(key, Value()));
        map.emplace(key, owner);
        // Link only after the map insert succeeded, so the list never references a node
        // that the map does not own.
        Node* node = owner.get();
        Node* oldTail = tail.load();
        node->prev.store(oldTail);
        if (oldTail) {
            oldTail->next.store(node);
        } else {
            head.store(node);
        }
        tail.store(node);
        count.fetch_add(1);
        return node;
    }

    /// Returns the node for `key`, or nullptr if absent. Caller must hold `mtx`.
    Node* getNode(const Key& key) const {
        auto found = map.find(key);
        if (found != map.end()) {
            return found->second.get();
        }
        return nullptr;
    }

public:
    /// Forward iterator over entries in insertion order. Not synchronized with remove()/clear().
    class iterator {
    public:
        // Iterator traits.
        using iterator_category = std::forward_iterator_tag;
        using value_type = std::pair<const Key, Value>;
        using difference_type = std::ptrdiff_t;
        using pointer = value_type*;
        using reference = value_type&;

        iterator(Node* ptr) : current(ptr) {}

        reference operator*() const {
            return current->data;
        }
        pointer operator->() const {
            return &(current->data);
        }
        iterator& operator++() {
            if (current) {
                current = current->next.load();
            }
            return *this;
        }
        iterator operator++(int) {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        friend bool operator==(const iterator& a, const iterator& b) {
            return a.current == b.current;
        }
        friend bool operator!=(const iterator& a, const iterator& b) {
            return a.current != b.current;
        }

    private:
        Node* current;
    };

    /// Read-only forward iterator over entries in insertion order. Not synchronized with remove()/clear().
    class const_iterator {
    public:
        // Iterator traits.
        using iterator_category = std::forward_iterator_tag;
        using value_type = std::pair<const Key, Value>;
        using difference_type = std::ptrdiff_t;
        using pointer = const value_type*;
        using reference = const value_type&;

        const_iterator(Node* ptr) : current(ptr) {}

        reference operator*() const {
            return current->data;
        }
        pointer operator->() const {
            return &(current->data);
        }
        const_iterator& operator++() {
            if (current) {
                current = current->next.load();
            }
            return *this;
        }
        const_iterator operator++(int) {
            const_iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        friend bool operator==(const const_iterator& a, const const_iterator& b) {
            return a.current == b.current;
        }
        friend bool operator!=(const const_iterator& a, const const_iterator& b) {
            return a.current != b.current;
        }

    private:
        Node* current;
    };

    // Iteration entry points (insertion order; end() is the null position).
    iterator begin() {
        return iterator(head.load());
    }
    iterator end() {
        return iterator(nullptr);
    }
    const_iterator begin() const {
        return const_iterator(head.load());
    }
    const_iterator end() const {
        return const_iterator(nullptr);
    }
    const_iterator cbegin() const {
        return const_iterator(head.load());
    }
    const_iterator cend() const {
        return const_iterator(nullptr);
    }

    /// Prints the entries in insertion order as `[{ k : v },{ k : v }]`.
    friend std::ostream& operator<<(std::ostream& os, const ConcurrentMap<Key, Value>& lm) {
        os << "[";
        bool first = true;
        for (const auto& pair : lm) {
            if (!first) {
                os << ",";
            }
            os << "{ " << pair.first << " : " << pair.second << " }";
            first = false;
        }
        os << "]";
        return os;
    }
};
}
#endif