USB_Config_Vendor/CC_SDK/Include/basic/ConcurrentMap.h

327 lines
9.7 KiB
C
Raw Normal View History

2026-02-03 14:36:30 +08:00
#ifndef CONCURRENTMAP_H
#define CONCURRENTMAP_H
#include <atomic>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <utility>

#include "ConcurrentLinkedList.h"
namespace CTL {
/**
 * ConcurrentMap: a hash map that also remembers insertion order.
 *
 * Each entry lives in a Node owned by a std::shared_ptr stored in an
 * std::unordered_map (average O(1) lookup). The same nodes are chained into a
 * doubly linked list through non-owning prev/next pointers, so begin()/end(),
 * values(), toArrayList(), toConcurrentLinkedList() and operator<< visit the
 * entries in insertion order.
 *
 * Thread-safety: operator[], get, remove, clear, values, toArrayList,
 * toConcurrentLinkedList and operator<< hold an internal mutex for their whole
 * duration; size() and IsEmpty() read an atomic counter without locking.
 * The mutex does NOT protect what is handed back to the caller: references
 * returned by operator[]/get() and iterators dangle if another thread removes
 * that entry or clears the map, and accesses through them are unsynchronized.
 * Callers sharing such references across threads must synchronize themselves.
 *
 * The map is neither copyable nor movable (it owns a mutex and atomics).
 *
 * @tparam Key   key type; must work with std::hash<Key> and operator==
 * @tparam Value mapped type; operator[] additionally requires it to be
 *               default-constructible
 */
template <typename Key, typename Value>
class ConcurrentMap {
public:
    // Standard container type aliases.
    using key_type = Key;
    using mapped_type = Value;
    using value_type = std::pair<const Key, Value>;
    using size_type = std::size_t;

private:
    // One entry: the key/value pair plus its links in the insertion-order list.
    // The links are non-owning; ownership lives in `map`.
    struct Node {
        value_type data;
        std::atomic<Node*> next;
        std::atomic<Node*> prev;
        explicit Node(const value_type& d) : data(d), next(nullptr), prev(nullptr) {}
    };
    using node_ptr = std::shared_ptr<Node>;

    std::atomic<Node*> head;                // oldest entry, nullptr when empty
    std::atomic<Node*> tail;                // newest entry, nullptr when empty
    std::unordered_map<Key, node_ptr> map;  // sole owner of every node
    std::atomic<size_type> count;           // mirrors map.size(); readable without the lock
    mutable std::mutex mtx;                 // guards `map` and all list links

public:
    /// Creates an empty map.
    ConcurrentMap() : head(nullptr), tail(nullptr), count(0) {}

    /// Releases every node.
    ~ConcurrentMap() {
        clear();
    }

    /// Removes every entry.
    void clear() {
        std::lock_guard<std::mutex> lock(mtx);
        // Do not walk the list node by node: erasing a node from `map` frees it,
        // so touching it afterwards would be a use-after-free. Reset the ends and
        // let map.clear() drop the last owner of every node in one go.
        head.store(nullptr);
        tail.store(nullptr);
        map.clear();
        count.store(0);
    }

    /// @return the number of entries (lock-free snapshot).
    size_type size() const {
        return count.load();
    }

    /// @return true when the map holds no entries (lock-free snapshot).
    bool IsEmpty() const {
        return count.load() == 0;
    }

    /**
     * Returns the value for `key`, appending a default-constructed value first
     * if the key is absent.
     * @return reference into the map (see the class note on lifetime).
     */
    Value& operator[](const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        // Look up first so an existing key costs no allocation and no Value().
        if (Node* existing = findLocked(key)) {
            return existing->data.second;
        }
        return appendLocked(std::make_shared<Node>(value_type(key, Value())))->data.second;
    }

    /**
     * Returns the value for `key`.
     * @throws std::out_of_range if the key is not present.
     */
    Value& get(const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        Node* node = findLocked(key);
        if (node == nullptr) {
            throw std::out_of_range("Key not found in CCLinkedMap");
        }
        return node->data.second;
    }

    /**
     * Returns the value for `key` (const overload).
     * @throws std::out_of_range if the key is not present.
     */
    const Value& get(const Key& key) const {
        std::lock_guard<std::mutex> lock(mtx);
        const Node* node = findLocked(key);
        if (node == nullptr) {
            throw std::out_of_range("Key not found in CCLinkedMap");
        }
        return node->data.second;
    }

    /// Removes `key` if present; otherwise does nothing.
    void remove(const Key& key) {
        std::lock_guard<std::mutex> lock(mtx);
        auto it = map.find(key);
        if (it == map.end()) {
            return;
        }
        unlinkLocked(it->second.get());  // detach from the list before the node is freed
        map.erase(it);                   // drops the owning shared_ptr
        count.fetch_sub(1);
    }

    /// @return a copy of all values in insertion order.
    std::list<Value> values() {
        std::lock_guard<std::mutex> lock(mtx);
        std::list<Value> result;
        visitValuesLocked([&result](Value& v) { result.push_back(v); });
        return result;
    }

    /// @return a copy of all values in insertion order.
    ArrayList<Value> toArrayList() {
        std::lock_guard<std::mutex> lock(mtx);
        ArrayList<Value> result;
        visitValuesLocked([&result](Value& v) { result.add(v); });
        return result;
    }

    /// @return a copy of all values in insertion order.
    ConcurrentLinkedList<Value> toConcurrentLinkedList() {
        std::lock_guard<std::mutex> lock(mtx);
        ConcurrentLinkedList<Value> result;
        visitValuesLocked([&result](Value& v) { result.add(v); });
        return result;
    }

private:
    // Returns the node for `key` or nullptr. Caller must hold `mtx`.
    Node* findLocked(const Key& key) const {
        auto it = map.find(key);
        return it == map.end() ? nullptr : it->second.get();
    }

    // Stores `node` in the map and appends it to the tail of the list.
    // Caller must hold `mtx`. If the key is already present, the existing node
    // is returned and `node` is discarded.
    Node* appendLocked(node_ptr node) {
        Node* raw = node.get();
        // Insert into the map first: it may throw, and the list is untouched until it succeeds.
        auto inserted = map.try_emplace(raw->data.first, std::move(node));
        if (!inserted.second) {
            return inserted.first->second.get();
        }
        Node* oldTail = tail.load();
        raw->prev.store(oldTail);
        if (oldTail != nullptr) {
            oldTail->next.store(raw);
        } else {
            head.store(raw);  // first entry
        }
        tail.store(raw);
        count.fetch_add(1);
        return raw;
    }

    // Splices `node` out of the list, fixing head/tail. Caller must hold `mtx`.
    void unlinkLocked(Node* node) {
        Node* before = node->prev.load();
        Node* after = node->next.load();
        if (before != nullptr) {
            before->next.store(after);
        } else {
            head.store(after);
        }
        if (after != nullptr) {
            after->prev.store(before);
        } else {
            tail.store(before);
        }
    }

    // Calls visit(value) for every entry in insertion order. Caller must hold `mtx`.
    template <typename Visitor>
    void visitValuesLocked(Visitor&& visit) {
        for (Node* n = head.load(); n != nullptr; n = n->next.load()) {
            visit(n->data.second);
        }
    }

public:
    /// Forward iterator over entries in insertion order (not lock-protected).
    class iterator {
    public:
        using iterator_category = std::forward_iterator_tag;
        using value_type = std::pair<const Key, Value>;
        using difference_type = std::ptrdiff_t;
        using pointer = value_type*;
        using reference = value_type&;

        iterator() = default;
        iterator(Node* ptr) : current(ptr) {}

        reference operator*() const {
            return current->data;
        }
        pointer operator->() const {
            return &(current->data);
        }
        iterator& operator++() {
            if (current) {
                current = current->next.load();
            }
            return *this;
        }
        iterator operator++(int) {
            iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        friend bool operator==(const iterator& a, const iterator& b) {
            return a.current == b.current;
        }
        friend bool operator!=(const iterator& a, const iterator& b) {
            return a.current != b.current;
        }

    private:
        Node* current = nullptr;
    };

    /// Read-only forward iterator over entries in insertion order (not lock-protected).
    class const_iterator {
    public:
        using iterator_category = std::forward_iterator_tag;
        using value_type = std::pair<const Key, Value>;
        using difference_type = std::ptrdiff_t;
        using pointer = const value_type*;
        using reference = const value_type&;

        const_iterator() = default;
        const_iterator(Node* ptr) : current(ptr) {}

        reference operator*() const {
            return current->data;
        }
        pointer operator->() const {
            return &(current->data);
        }
        const_iterator& operator++() {
            if (current) {
                current = current->next.load();
            }
            return *this;
        }
        const_iterator operator++(int) {
            const_iterator tmp = *this;
            ++(*this);
            return tmp;
        }
        friend bool operator==(const const_iterator& a, const const_iterator& b) {
            return a.current == b.current;
        }
        friend bool operator!=(const const_iterator& a, const const_iterator& b) {
            return a.current != b.current;
        }

    private:
        Node* current = nullptr;
    };

    // Iterator accessors; end() is represented by a null node pointer.
    iterator begin() {
        return iterator(head.load());
    }
    iterator end() {
        return iterator(nullptr);
    }
    const_iterator begin() const {
        return const_iterator(head.load());
    }
    const_iterator end() const {
        return const_iterator(nullptr);
    }
    const_iterator cbegin() const {
        return const_iterator(head.load());
    }
    const_iterator cend() const {
        return const_iterator(nullptr);
    }

    /// Prints the map as `[{ k : v },{ k : v }]` in insertion order.
    friend std::ostream& operator<<(std::ostream& os, const ConcurrentMap<Key, Value>& lm) {
        std::lock_guard<std::mutex> lock(lm.mtx);  // keep the list stable while printing
        os << "[";
        bool first = true;
        for (const auto& entry : lm) {
            if (!first) {
                os << ",";
            }
            os << "{ " << entry.first << " : " << entry.second << " }";
            first = false;
        }
        os << "]";
        return os;
    }
};
}
#endif