C# 使用Vector256写了一个简单的帮助类Vector256Helper

发布于:2024-06-22 ⋅ 阅读:(122) ⋅ 点赞:(0)

 当数据量较大时，用普通代码逐个计算非常耗时，这里简单地利用 SIMD（Vector256 / AVX2）指令进行加速处理。

internal unsafe class Vector256Helper
    {
        /// <summary>
        /// Counts how many elements of <paramref name="array"/> equal <paramref name="elementToCount"/>.
        /// Uses AVX2 when the CPU supports it and falls back to a scalar loop otherwise.
        /// </summary>
        /// <param name="array">The array to scan.</param>
        /// <param name="elementToCount">The value whose occurrences are counted.</param>
        /// <returns>The number of matching elements (0 for an empty array).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
        public static int Count(int[] array, int elementToCount)
        {
            ArgumentNullException.ThrowIfNull(array);

            int count = 0;
            int i = 0;
            fixed (int* pArray = array)
            {
                if (Avx2.IsSupported)
                {
                    int vectorSize = Vector256<int>.Count;
                    int limit = array.Length - (array.Length % vectorSize);
                    Vector256<int> target = Vector256.Create(elementToCount);
                    // CompareEqual yields -1 (all bits set) in every matching lane, so summing the
                    // masks accumulates the *negated* match count per lane.
                    Vector256<int> negatedCounts = Vector256<int>.Zero;
                    for (; i < limit; i += vectorSize)
                    {
                        negatedCounts += Avx2.CompareEqual(Avx.LoadVector256(pArray + i), target);
                    }
                    for (int lane = 0; lane < vectorSize; lane++)
                    {
                        count -= negatedCounts.GetElement(lane);
                    }
                }

                // Scalar tail (or the whole array when AVX2 is unavailable).
                for (; i < array.Length; i++)
                {
                    if (pArray[i] == elementToCount)
                    {
                        count++;
                    }
                }
            }
            return count;
        }

        /// <summary>
        /// Sums all elements of <paramref name="buffer"/> using 32-bit integer arithmetic.
        /// Uses AVX2 when the CPU supports it and falls back to a scalar loop otherwise.
        /// </summary>
        /// <param name="buffer">The array to sum.</param>
        /// <returns>
        /// The sum (0 for an empty array). Like ordinary <c>int</c> addition in an unchecked
        /// context, the result wraps on overflow; use <see cref="Avg"/> for an overflow-safe mean.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        public static int Sum(int[] buffer)
        {
            ArgumentNullException.ThrowIfNull(buffer);

            int sum = 0;
            int j = 0;
            fixed (int* p = buffer)
            {
                if (Avx2.IsSupported)
                {
                    int vectorSize = Vector256<int>.Count;
                    int limit = buffer.Length - (buffer.Length % vectorSize);
                    Vector256<int> sumV = Vector256<int>.Zero;
                    for (; j < limit; j += vectorSize)
                    {
                        sumV = Avx2.Add(sumV, Avx.LoadVector256(p + j));
                    }
                    for (int lane = 0; lane < vectorSize; lane++)
                    {
                        sum += sumV.GetElement(lane);
                    }
                }

                // Scalar tail (or the whole array when AVX2 is unavailable).
                for (; j < buffer.Length; j++)
                {
                    sum += p[j];
                }
            }
            return sum;
        }

        /// <summary>
        /// Computes the arithmetic mean of <paramref name="buffer"/>.
        /// The elements are summed into 64-bit accumulators, so the mean stays correct even when
        /// the 32-bit <see cref="Sum"/> would overflow.
        /// </summary>
        /// <remarks>
        /// This is an instance method (unlike the other helpers) only to stay source-compatible
        /// with existing callers; it reads no instance state.
        /// </remarks>
        /// <param name="buffer">The array to average.</param>
        /// <returns>The mean; <see cref="double.NaN"/> for an empty array (0 / 0.0).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        public double Avg(int[] buffer)
        {
            ArgumentNullException.ThrowIfNull(buffer);
            return SumInt64(buffer) / (double)buffer.Length;
        }

        /// <summary>
        /// Sums all elements into a 64-bit total; cannot overflow for any <c>int[]</c> length.
        /// </summary>
        private static long SumInt64(int[] buffer)
        {
            long sum = 0;
            int j = 0;
            fixed (int* p = buffer)
            {
                if (Avx2.IsSupported)
                {
                    // Each step sign-extends four ints (one Vector128<int>) into four longs.
                    const int IntsPerStep = 4;
                    int limit = buffer.Length - (buffer.Length % IntsPerStep);
                    Vector256<long> sumV = Vector256<long>.Zero;
                    for (; j < limit; j += IntsPerStep)
                    {
                        sumV = Avx2.Add(sumV, Avx2.ConvertToVector256Int64(Sse2.LoadVector128(p + j)));
                    }
                    for (int lane = 0; lane < Vector256<long>.Count; lane++)
                    {
                        sum += sumV.GetElement(lane);
                    }
                }

                for (; j < buffer.Length; j++)
                {
                    sum += p[j];
                }
            }
            return sum;
        }

        /// <summary>
        /// Returns the largest element of <paramref name="array"/>.
        /// Uses AVX2 when the CPU supports it and the array holds at least one full vector;
        /// otherwise falls back to a scalar loop.
        /// </summary>
        /// <param name="array">The array to scan.</param>
        /// <returns>The maximum element.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
        /// <exception cref="InvalidOperationException"><paramref name="array"/> is empty.</exception>
        public static int Max(int[] array)
        {
            ArgumentNullException.ThrowIfNull(array);
            if (array.Length == 0)
            {
                throw new InvalidOperationException("Sequence contains no elements.");
            }

            int maxElement = array[0];
            int i = 0;
            fixed (int* pArray = array)
            {
                int vectorSize = Vector256<int>.Count;
                // A full 256-bit load is only safe when at least vectorSize elements exist.
                if (Avx2.IsSupported && array.Length >= vectorSize)
                {
                    int limit = array.Length - (array.Length % vectorSize);
                    Vector256<int> vectorMax = Avx.LoadVector256(pArray);
                    // Continue at the next whole vector so every load stays within [0, limit).
                    for (i = vectorSize; i < limit; i += vectorSize)
                    {
                        vectorMax = Avx2.Max(vectorMax, Avx.LoadVector256(pArray + i));
                    }
                    for (int lane = 0; lane < vectorSize; lane++)
                    {
                        maxElement = Math.Max(maxElement, vectorMax.GetElement(lane));
                    }
                }

                // Scalar tail (or the whole array when the vector path was skipped).
                for (; i < array.Length; i++)
                {
                    maxElement = Math.Max(maxElement, pArray[i]);
                }
            }
            return maxElement;
        }

        /// <summary>
        /// Returns the smallest element of <paramref name="array"/>.
        /// Uses AVX2 when the CPU supports it and the array holds at least one full vector;
        /// otherwise falls back to a scalar loop.
        /// </summary>
        /// <param name="array">The array to scan.</param>
        /// <returns>The minimum element.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="array"/> is null.</exception>
        /// <exception cref="InvalidOperationException"><paramref name="array"/> is empty.</exception>
        public static int Min(int[] array)
        {
            ArgumentNullException.ThrowIfNull(array);
            if (array.Length == 0)
            {
                throw new InvalidOperationException("Sequence contains no elements.");
            }

            int minElement = array[0];
            int i = 0;
            fixed (int* pArray = array)
            {
                int vectorSize = Vector256<int>.Count;
                // A full 256-bit load is only safe when at least vectorSize elements exist.
                if (Avx2.IsSupported && array.Length >= vectorSize)
                {
                    int limit = array.Length - (array.Length % vectorSize);
                    Vector256<int> vectorMin = Avx.LoadVector256(pArray);
                    // Continue at the next whole vector so every load stays within [0, limit).
                    for (i = vectorSize; i < limit; i += vectorSize)
                    {
                        vectorMin = Avx2.Min(vectorMin, Avx.LoadVector256(pArray + i));
                    }
                    for (int lane = 0; lane < vectorSize; lane++)
                    {
                        minElement = Math.Min(minElement, vectorMin.GetElement(lane));
                    }
                }

                // Scalar tail (or the whole array when the vector path was skipped).
                for (; i < array.Length; i++)
                {
                    minElement = Math.Min(minElement, pArray[i]);
                }
            }
            return minElement;
        }
    }


网站公告

今日签到

点亮在社区的每一天
去签到