Tpetra parallel linear algebra Version of the Day
Loading...
Searching...
No Matches
Tpetra_Details_normImpl.hpp
Go to the documentation of this file.
1// @HEADER
2// ***********************************************************************
3//
4// Tpetra: Templated Linear Algebra Services Package
5// Copyright (2008) Sandia Corporation
6//
7// Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
8// the U.S. Government retains certain rights in this software.
9//
10// Redistribution and use in source and binary forms, with or without
11// modification, are permitted provided that the following conditions are
12// met:
13//
14// 1. Redistributions of source code must retain the above copyright
15// notice, this list of conditions and the following disclaimer.
16//
17// 2. Redistributions in binary form must reproduce the above copyright
18// notice, this list of conditions and the following disclaimer in the
19// documentation and/or other materials provided with the distribution.
20//
21// 3. Neither the name of the Corporation nor the names of the
22// contributors may be used to endorse or promote products derived from
23// this software without specific prior written permission.
24//
25// THIS SOFTWARE IS PROVIDED BY SANDIA CORPORATION "AS IS" AND ANY
26// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
28// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL SANDIA CORPORATION OR THE
29// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
30// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
31// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
32// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
33// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
34// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
35// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36//
37// Questions? Contact Michael A. Heroux (maherou@sandia.gov)
38//
39// ************************************************************************
40// @HEADER
41
42#ifndef TPETRA_DETAILS_NORMIMPL_HPP
43#define TPETRA_DETAILS_NORMIMPL_HPP
44
53
54#include "TpetraCore_config.h"
55#include "Kokkos_Core.hpp"
56#include "Teuchos_ArrayView.hpp"
57#include "Teuchos_CommHelpers.hpp"
58#include "KokkosBlas.hpp"
59#include "Kokkos_ArithTraits.hpp"
60
61#ifndef DOXYGEN_SHOULD_SKIP_THIS
62namespace Teuchos {
63 template<class T>
64 class ArrayView; // forward declaration
65 template<class OrdinalType>
66 class Comm; // forward declaration
67}
68#endif // DOXYGEN_SHOULD_SKIP_THIS
69
71// Declarations start here
73
74namespace Tpetra {
75namespace Details {
76
/// \brief Input argument for normImpl() (which see).
///
/// Selects which vector norm the norm routines in this file compute.
enum EWhichNorm {
  NORM_ONE, //!< Use the one-norm
  NORM_TWO, //!< Use the two-norm
  NORM_INF  //!< Use the infinity-norm
};
83
85template <class ValueType,
86 class ArrayLayout,
87 class DeviceType,
88 class MagnitudeType>
89void
91 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>& X,
93 const Teuchos::ArrayView<const size_t>& whichVecs,
94 const bool isConstantStride,
95 const bool isDistributed,
96 const Teuchos::Comm<int>* comm);
97
98} // namespace Details
99} // namespace Tpetra
100
102// Definitions start here
104
105namespace Tpetra {
106namespace Details {
107namespace Impl {
108
109template<class RV, class XMV>
110void
111lclNormImpl (const RV& normsOut,
112 const XMV& X,
113 const size_t numVecs,
114 const Teuchos::ArrayView<const size_t>& whichVecs,
115 const bool constantStride,
116 const EWhichNorm whichNorm)
117{
118 using Kokkos::ALL;
119 using Kokkos::subview;
120 using mag_type = typename RV::non_const_value_type;
121
122 static_assert (static_cast<int> (RV::Rank) == 1,
123 "Tpetra::MultiVector::lclNormImpl: "
124 "The first argument normsOut must have rank 1.");
125 static_assert (Kokkos::is_view<XMV>::value,
126 "Tpetra::MultiVector::lclNormImpl: "
127 "The second argument X is not a Kokkos::View.");
128 static_assert (static_cast<int> (XMV::Rank) == 2,
129 "Tpetra::MultiVector::lclNormImpl: "
130 "The second argument X must have rank 2.");
131
132 const size_t lclNumRows = static_cast<size_t> (X.extent (0));
133 TEUCHOS_TEST_FOR_EXCEPTION
134 (lclNumRows != 0 && constantStride &&
135 static_cast<size_t> (X.extent (1)) != numVecs,
136 std::logic_error, "Constant Stride X's dimensions are " << X.extent (0)
137 << " x " << X.extent (1) << ", which differ from the local dimensions "
138 << lclNumRows << " x " << numVecs << ". Please report this bug to "
139 "the Tpetra developers.");
140 TEUCHOS_TEST_FOR_EXCEPTION
141 (lclNumRows != 0 && ! constantStride &&
142 static_cast<size_t> (X.extent (1)) < numVecs,
143 std::logic_error, "Strided X's dimensions are " << X.extent (0) << " x "
144 << X.extent (1) << ", which are incompatible with the local dimensions "
145 << lclNumRows << " x " << numVecs << ". Please report this bug to "
146 "the Tpetra developers.");
147
148 if (lclNumRows == 0) {
149 const mag_type zeroMag = Kokkos::ArithTraits<mag_type>::zero ();
150 // DEEP_COPY REVIEW - VALUE-TO-DEVICE
151 using execution_space = typename RV::execution_space;
152 Kokkos::deep_copy (execution_space(), normsOut, zeroMag);
153 }
154 else { // lclNumRows != 0
155 if (constantStride) {
156 if (whichNorm == NORM_INF) {
157 KokkosBlas::nrminf (normsOut, X);
158 }
159 else if (whichNorm == NORM_ONE) {
160 KokkosBlas::nrm1 (normsOut, X);
161 }
162 else if (whichNorm == NORM_TWO) {
163 KokkosBlas::nrm2_squared (normsOut, X);
164 }
165 else {
166 TEUCHOS_TEST_FOR_EXCEPTION
167 (true, std::logic_error, "Should never get here!");
168 }
169 }
170 else { // not constant stride
171 // NOTE (mfh 15 Jul 2014, 11 Apr 2019) This does a kernel launch
172 // for every column. It might be better to have a kernel that
173 // does the work all at once. On the other hand, we don't
174 // prioritize performance of MultiVector views of noncontiguous
175 // columns.
176 for (size_t k = 0; k < numVecs; ++k) {
177 const size_t X_col = constantStride ? k : whichVecs[k];
178 if (whichNorm == NORM_INF) {
179 KokkosBlas::nrminf (subview (normsOut, k),
180 subview (X, ALL (), X_col));
181 }
182 else if (whichNorm == NORM_ONE) {
183 KokkosBlas::nrm1 (subview (normsOut, k),
184 subview (X, ALL (), X_col));
185 }
186 else if (whichNorm == NORM_TWO) {
187 KokkosBlas::nrm2_squared (subview (normsOut, k),
188 subview (X, ALL (), X_col));
189 }
190 else {
191 TEUCHOS_TEST_FOR_EXCEPTION
192 (true, std::logic_error, "Should never get here!");
193 }
194 } // for each column
195 } // constantStride
196 } // lclNumRows != 0
197}
198
199// Kokkos::parallel_for functor for applying square root to each
200// entry of a 1-D Kokkos::View.
201template<class ViewType>
202class SquareRootFunctor {
203public:
204 typedef typename ViewType::execution_space execution_space;
205 typedef typename ViewType::size_type size_type;
206
207 SquareRootFunctor (const ViewType& theView) :
208 theView_ (theView)
209 {}
210
211 KOKKOS_INLINE_FUNCTION void
212 operator() (const size_type& i) const
213 {
214 typedef typename ViewType::non_const_value_type value_type;
215 typedef Kokkos::Details::ArithTraits<value_type> KAT;
216 theView_(i) = KAT::sqrt (theView_(i));
217 }
218
219private:
220 ViewType theView_;
221};
222
223template<class RV>
224void
225gblNormImpl (const RV& normsOut,
226 const Teuchos::Comm<int>* const comm,
227 const bool distributed,
228 const EWhichNorm whichNorm)
229{
230 using Teuchos::REDUCE_MAX;
231 using Teuchos::REDUCE_SUM;
232 using Teuchos::reduceAll;
233 typedef typename RV::non_const_value_type mag_type;
234
235 const size_t numVecs = normsOut.extent (0);
236
237 // If the MultiVector is distributed over multiple processes, do
238 // the distributed (interprocess) part of the norm. We assume
239 // that the MPI implementation can read from and write to device
240 // memory.
241 //
242 // replaceMap() may have removed some processes. Those processes
243 // have a null Map. They must not participate in any collective
244 // operations. We ask first whether the Map is null, because
245 // isDistributed() defers that question to the Map. We still
246 // compute and return local norms for processes not participating
247 // in collective operations; those probably don't make any sense,
248 // but it doesn't hurt to do them, since it's illegal to call
249 // norm*() on those processes anyway.
250 if (distributed && comm != nullptr) {
251 // The calling process only participates in the collective if
252 // both the Map and its Comm on that process are nonnull.
253
254 const int nv = static_cast<int> (numVecs);
255 const bool commIsInterComm = ::Tpetra::Details::isInterComm (*comm);
256
257 if (commIsInterComm) {
258 RV lclNorms (Kokkos::ViewAllocateWithoutInitializing ("MV::normImpl lcl"), numVecs);
259 // DEEP_COPY REVIEW - DEVICE-TO-DEVICE
260 using execution_space = typename RV::execution_space;
261 Kokkos::deep_copy (execution_space(), lclNorms, normsOut);
262 const mag_type* const lclSum = lclNorms.data ();
263 mag_type* const gblSum = normsOut.data ();
264
265 if (whichNorm == NORM_INF) {
266 reduceAll<int, mag_type> (*comm, REDUCE_MAX, nv, lclSum, gblSum);
267 } else {
268 reduceAll<int, mag_type> (*comm, REDUCE_SUM, nv, lclSum, gblSum);
269 }
270 } else {
271 mag_type* const gblSum = normsOut.data ();
272 if (whichNorm == NORM_INF) {
273 reduceAll<int, mag_type> (*comm, REDUCE_MAX, nv, gblSum, gblSum);
274 } else {
275 reduceAll<int, mag_type> (*comm, REDUCE_SUM, nv, gblSum, gblSum);
276 }
277 }
278 }
279
280 if (whichNorm == NORM_TWO) {
281 // Replace the norm-squared results with their square roots in
282 // place, to get the final output. If the device memory and
283 // the host memory are the same, it probably doesn't pay to
284 // launch a parallel kernel for that, since there isn't enough
285 // parallelism for the typical MultiVector case.
286 const bool inHostMemory =
287 std::is_same<typename RV::memory_space,
288 typename RV::host_mirror_space::memory_space>::value;
289 if (inHostMemory) {
290 for (size_t j = 0; j < numVecs; ++j) {
291 normsOut(j) = Kokkos::Details::ArithTraits<mag_type>::sqrt (normsOut(j));
292 }
293 }
294 else {
295 // There's not as much parallelism now, but that's OK. The
296 // point of doing parallel dispatch here is to keep the norm
297 // results on the device, thus avoiding a copy to the host
298 // and back again.
299 SquareRootFunctor<RV> f (normsOut);
300 typedef typename RV::execution_space execution_space;
301 typedef Kokkos::RangePolicy<execution_space, size_t> range_type;
302 Kokkos::parallel_for (range_type (0, numVecs), f);
303 }
304 }
305}
306
307} // namespace Impl
308
309template <class ValueType,
310 class ArrayLayout,
311 class DeviceType,
312 class MagnitudeType>
313void
315 const Kokkos::View<const ValueType**, ArrayLayout, DeviceType>& X,
316 const EWhichNorm whichNorm,
317 const Teuchos::ArrayView<const size_t>& whichVecs,
318 const bool isConstantStride,
319 const bool isDistributed,
320 const Teuchos::Comm<int>* comm)
321{
322 using RV = Kokkos::View<MagnitudeType*, Kokkos::HostSpace>;
323 //using XMV = Kokkos::View<const ValueType**, ArrayLayout, DeviceType>;
324 //using pair_type = std::pair<size_t, size_t>;
325
326 const size_t numVecs = isConstantStride ?
327 static_cast<size_t> (X.extent (1)) :
328 static_cast<size_t> (whichVecs.size ());
329 if (numVecs == 0) {
330 return; // nothing to do
331 }
333
334 Impl::lclNormImpl (normsOut, X, numVecs, whichVecs,
335 isConstantStride, whichNorm);
336 Impl::gblNormImpl (normsOut, comm, isDistributed, whichNorm);
337}
338
339} // namespace Details
340} // namespace Tpetra
341
342#endif // TPETRA_DETAILS_NORMIMPL_HPP
Struct that holds views of the contents of a CrsMatrix.
Implementation details of Tpetra.
void normImpl(MagnitudeType norms[], const Kokkos::View< const ValueType **, ArrayLayout, DeviceType > &X, const EWhichNorm whichNorm, const Teuchos::ArrayView< const size_t > &whichVecs, const bool isConstantStride, const bool isDistributed, const Teuchos::Comm< int > *comm)
Implementation of MultiVector norms.
EWhichNorm
Input argument for normImpl() (which see).
bool isInterComm(const Teuchos::Comm< int > &)
Return true if and only if the input communicator wraps an MPI intercommunicator.
Namespace Tpetra contains the class and methods constituting the Tpetra library.