Actual source code: matdiagonalcupm.hpp
1: #pragma once
3: #include <petscmat.h>
5: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
7: #include <petsc/private/cupminterface.hpp>
8: #include <petsc/private/cupmobject.hpp>
9: #include <petsc/private/deviceimpl.h>
10: #include <petsc/private/vecimpl.h>
11: #include <petsc/private/veccupmimpl.h>
12: #include <petsc/private/matimpl.h>
14: #include <thrust/device_ptr.h>
15: #include <thrust/iterator/zip_iterator.h>
16: #include <thrust/transform_reduce.h>
17: #include <thrust/tuple.h>
19: namespace Petsc
20: {
22: namespace device
23: {
25: namespace cupm
26: {
28: namespace impl
29: {
31: template <DeviceType T, typename VecType>
32: struct MatDiagonal_CUPM : vec::cupm::impl::Vec_CUPMBase<T, VecType> {
33: PETSC_CUPMOBJECT_HEADER(T);
34: using base_type = ::Petsc::vec::cupm::impl::Vec_CUPMBase<T, VecType>;
35: friend base_type;
37: static PetscErrorCode ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept;
38: static PetscErrorCode ANormSq(Mat A, Vec x, PetscReal *z) noexcept;
39: };
41: namespace detail
42: {
43: struct adot_transform {
44: using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar>;
46: PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return PetscConj(thrust::get<1>(tup)) * thrust::get<2>(tup) * thrust::get<0>(tup); }
47: };
48: } // namespace detail
50: template <Petsc::device::cupm::DeviceType T, typename VecType>
51: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept
52: {
53: PetscDeviceContext dctx;
54: cupmStream_t stream;
55: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
56: PetscScalar zero = 0.;
57: const PetscInt n = x->map->n;
59: PetscFunctionBegin;
60: PetscCall(GetHandles_(&dctx, &stream));
62: const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
63: const auto ydptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, y).data());
64: const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());
66: // clang-format off
67: PetscCallThrust(
68: *z = THRUST_CALL(
69: thrust::transform_reduce,
70: stream,
71: thrust::make_zip_iterator(thrust::make_tuple(xdptr, ydptr, wdptr)),
72: thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, ydptr + n, wdptr + n)),
73: detail::adot_transform{},
74: zero,
75: thrust::plus<PetscScalar>()
76: )
77: );
78: // clang-format on
79: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
83: namespace detail
84: {
85: struct anorm_transform {
86: using argument_type = thrust::tuple<PetscScalar, PetscScalar>;
88: PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return thrust::get<1>(tup) * PetscConj(thrust::get<0>(tup)) * thrust::get<0>(tup); }
89: };
90: } // namespace detail
92: template <Petsc::device::cupm::DeviceType T, typename VecType>
93: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ANormSq(Mat A, Vec x, PetscReal *z) noexcept
94: {
95: PetscDeviceContext dctx;
96: cupmStream_t stream;
97: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
98: PetscScalar zero = 0., res;
99: const PetscInt n = x->map->n;
101: PetscFunctionBegin;
102: PetscCall(GetHandles_(&dctx, &stream));
104: const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
105: const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());
107: // clang-format off
108: PetscCallThrust(
109: res = THRUST_CALL(
110: thrust::transform_reduce,
111: stream,
112: thrust::make_zip_iterator(thrust::make_tuple(xdptr, wdptr)),
113: thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, wdptr + n)),
114: detail::anorm_transform{},
115: zero,
116: thrust::plus<PetscScalar>()
117: )
118: );
119: // clang-format on
120: *z = PetscRealPart(res);
121: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
122: PetscFunctionReturn(PETSC_SUCCESS);
123: }
125: } // namespace impl
127: } // namespace cupm
129: } // namespace device
131: } // namespace Petsc