#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>

#include <cuda_runtime.h>
template <typename T>
struct MemoryDeleter {
bool UseCUDA; // 成员变量用于标记是否使用CUDA
MemoryDeleter(bool useCUDA) : UseCUDA(useCUDA) {} // 构造函数初始化UseCUDA
void operator()(T* ptr) {
if (UseCUDA) {
cudaError_t cudaStatus = cudaFree(ptr);
if (cudaStatus != cudaSuccess) {
std::cerr << "CUDA memory free error: " << cudaGetErrorString(cudaStatus) << std::endl;
}
}
else {
delete[] ptr; // 使用delete[]释放CPU内存
}
}
};
template <typename T, bool UseCUDA>
using SharedMemoryPtr = std::conditional_t<UseCUDA, std::shared_ptr<T>, std::unique_ptr<T[], MemoryDeleter<T>>>;
// Collection of static helpers for allocating, filling and copying buffers of
// T. The UseCUDA template flag selects, at compile time, whether the helpers
// go through the CUDA runtime (true) or operate on plain host memory (false).
template <typename T, bool UseCUDA>
class MemoryManager {
public:
// Allocates storage for `size` elements of T (through the CUDA runtime when
// UseCUDA is true, with new[] otherwise) and returns it wrapped in the owning
// SharedMemoryPtr type for this configuration.
static SharedMemoryPtr<T, UseCUDA> Allocate(size_t size);
// Fills the first `size` elements at `ptr` from the integer `value`. How the
// value is applied differs per backend; see the definition for details.
static void Set(T* ptr, int value, size_t size);
// Copies `size` elements from `src` to `dest`. When UseCUDA is true the
// transfer is issued as a host-to-device cudaMemcpy; otherwise it is a plain
// host memory copy.
static void Copy(T* dest, const T* src, size_t size);
};
/**
 * Allocates a buffer large enough for `size` elements of T and returns it in
 * the owning pointer type for this configuration.
 *
 * UseCUDA == true : the buffer is obtained with cudaMalloc.
 * UseCUDA == false: the buffer is obtained with new T[size].
 *
 * In both cases the returned pointer carries a MemoryDeleter<T>(UseCUDA), so
 * the matching release call is made when the owner is destroyed.
 *
 * @param size  Number of elements (not bytes) to allocate.
 * @return      Owning pointer to the new buffer.
 * @throws std::length_error   (CUDA path) if size * sizeof(T) would overflow size_t.
 * @throws std::runtime_error  (CUDA path) if cudaMalloc reports an error.
 * @throws std::bad_alloc      (host path) if new[] fails.
 */
template <typename T, bool UseCUDA>
SharedMemoryPtr<T, UseCUDA> MemoryManager<T, UseCUDA>::Allocate(size_t size) {
    T* ptr = nullptr;
    if constexpr (UseCUDA) {
        // Guard the byte-count multiplication: a wrapped value would make
        // cudaMalloc succeed with a buffer far smaller than requested.
        if (size > static_cast<size_t>(-1) / sizeof(T)) {
            throw std::length_error(
                "MemoryManager::Allocate: requested element count overflows size_t");
        }
        const cudaError_t status =
            cudaMalloc(reinterpret_cast<void**>(&ptr), size * sizeof(T));
        if (status != cudaSuccess) {
            // Surface the failure instead of silently handing back a null owner.
            throw std::runtime_error(std::string("CUDA memory allocation error: ") +
                                     cudaGetErrorString(status));
        }
    } else {
        ptr = new T[size];
    }
    return SharedMemoryPtr<T, UseCUDA>(ptr, MemoryDeleter<T>(UseCUDA));
}
/**
 * Sets the first `size` elements at `ptr` to static_cast<T>(value).
 *
 * Host path (UseCUDA == false): every element is assigned static_cast<T>(value).
 *
 * CUDA path (UseCUDA == true): cudaMemset writes one repeated byte, which only
 * matches an element-wise fill when every byte of static_cast<T>(value) equals
 * the low byte of `value` (for example value == 0, or -1 for integer types).
 * For arithmetic T:
 *   - if the byte fill is exact, cudaMemset is used directly;
 *   - otherwise a host staging buffer holding the converted value is copied to
 *     the device in chunks, so the result matches the host path.
 * For non-arithmetic T the raw byte fill via cudaMemset is kept as is.
 *
 * @param ptr    Destination buffer (device memory when UseCUDA is true).
 * @param value  Fill value, converted to T for arithmetic element types.
 * @param size   Number of elements (not bytes) to fill.
 * @throws std::runtime_error  if a CUDA runtime call reports an error.
 */
template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Set(T* ptr, int value, size_t size) {
    if constexpr (UseCUDA) {
        // Decide whether a plain byte fill reproduces static_cast<T>(value).
        bool byteFillIsExact = true;
        if constexpr (std::is_arithmetic_v<T>) {
            const T element = static_cast<T>(value);
            unsigned char elementBytes[sizeof(T)];
            std::memcpy(elementBytes, &element, sizeof(T));
            // cudaMemset stores `value` converted to unsigned char in each byte.
            const unsigned char memsetByte = static_cast<unsigned char>(value);
            for (size_t b = 0; b < sizeof(T); ++b) {
                if (elementBytes[b] != memsetByte) {
                    byteFillIsExact = false;
                    break;
                }
            }
        }

        cudaError_t status = cudaSuccess;
        if (byteFillIsExact) {
            status = cudaMemset(ptr, value, size * sizeof(T));
        } else if constexpr (std::is_arithmetic_v<T>) {
            // Stage the converted value on the host and upload it chunk by
            // chunk so the temporary stays small even for very large buffers.
            constexpr size_t kStagingElements = static_cast<size_t>(1) << 20;
            const size_t chunkLength = std::min<size_t>(size, kStagingElements);
            std::unique_ptr<T[]> staging(new T[chunkLength]);
            std::fill_n(staging.get(), chunkLength, static_cast<T>(value));
            for (size_t offset = 0; offset < size && status == cudaSuccess;
                 offset += chunkLength) {
                const size_t count = std::min(chunkLength, size - offset);
                status = cudaMemcpy(ptr + offset, staging.get(), count * sizeof(T),
                                    cudaMemcpyHostToDevice);
            }
        }

        if (status != cudaSuccess) {
            throw std::runtime_error(std::string("CUDA memory set error: ") +
                                     cudaGetErrorString(status));
        }
    } else {
        std::fill_n(ptr, size, static_cast<T>(value));
    }
}
/**
 * Copies `size` elements of T from `src` to `dest`.
 *
 * UseCUDA == true : the copy is issued as cudaMemcpy with
 *                   cudaMemcpyHostToDevice, so `src` is expected to be a host
 *                   pointer and `dest` a device pointer.
 * UseCUDA == false: the copy is a raw byte copy between two host buffers,
 *                   which must not overlap (memcpy semantics).
 *
 * Both paths copy raw bytes, so T is assumed to be trivially copyable.
 *
 * @param dest  Destination buffer with room for at least `size` elements.
 * @param src   Source buffer holding at least `size` elements.
 * @param size  Number of elements (not bytes) to copy.
 * @throws std::runtime_error  (CUDA path) if cudaMemcpy reports an error.
 */
template <typename T, bool UseCUDA>
void MemoryManager<T, UseCUDA>::Copy(T* dest, const T* src, size_t size) {
    const size_t byteCount = size * sizeof(T);
    if constexpr (UseCUDA) {
        const cudaError_t status =
            cudaMemcpy(dest, src, byteCount, cudaMemcpyHostToDevice);
        if (status != cudaSuccess) {
            throw std::runtime_error(std::string("CUDA memory copy error: ") +
                                     cudaGetErrorString(status));
        }
    } else {
        // memcpy with null pointers is undefined even for zero bytes, so skip
        // the call entirely when there is nothing to copy.
        if (byteCount != 0) {
            std::memcpy(dest, src, byteCount);
        }
    }
}
/**
 * Demo driver: allocates a device buffer of 512 * 512 * 500 floats, zeroes it,
 * fills a host buffer with 0, 1, 2, ... and uploads it to the device buffer.
 *
 * @return 0 on success, 1 if an exception was reported.
 */
int main() {
    // Element counts are passed as size_t to the helpers, so keep them
    // unsigned here instead of relying on int conversions.
    const size_t size = static_cast<size_t>(512) * 512 * 500;

    try {
        SharedMemoryPtr<float, true> ptr = MemoryManager<float, true>::Allocate(size);

        const int value = 0;
        MemoryManager<float, true>::Set(ptr.get(), value, size);

        // Owned by unique_ptr so the host buffer is released on every exit
        // path. The buffer is heap-allocated because it is far too large for
        // the stack.
        std::unique_ptr<float[]> hostData(new float[size]);
        for (size_t i = 0; i < size; ++i) {
            hostData[i] = static_cast<float>(i);
        }

        MemoryManager<float, true>::Copy(ptr.get(), hostData.get(), size);
        // Both buffers are released automatically when they go out of scope.
    } catch (const std::exception& error) {
        // Report the failure (for example std::bad_alloc from the host
        // allocation) instead of terminating through an uncaught exception.
        std::cerr << "Error: " << error.what() << std::endl;
        return 1;
    }
    return 0;
}