50#ifndef _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
51#define _ZOLTAN2_TPETRAROWMATRIXADAPTER_HPP_
57#include <Tpetra_RowMatrix.hpp>
// TpetraRowMatrixAdapter: Zoltan2 MatrixAdapter over a Tpetra::RowMatrix-like
// User type.
// NOTE(review): this text is a scraped listing -- the class head, several
// signatures and function bodies are missing, and the leading number on each
// line appears to be the original header's line number.
75template <
typename User,
typename UserCoord=User>
79#ifndef DOXYGEN_SHOULD_SKIP_THIS
87 typedef UserCoord userCoord_t;
100 int nWeightsPerRow=0);
// Local row / column / nonzero counts are forwarded to the wrapped matrix_.
152 return matrix_->getLocalNumRows();
156 return matrix_->getLocalNumCols();
160 return matrix_->getLocalNumEntries();
// getRowIDsView: hands out a raw pointer into rowMap_'s list of locally owned
// global IDs (non-owning ArrayView; nothing is copied), so it is only valid
// while rowMap_ is alive.
167 ArrayView<const gno_t> rowView = rowMap_->getLocalElementList();
168 rowIds = rowView.getRawPtr();
// getCRSView overloads: presumably expose the CRS arrays cached by the
// constructor (offset_, columnIds_, values_); bodies not visible here.
171 void getCRSView(ArrayRCP<const offset_t> &offsets, ArrayRCP<const gno_t> &colIds)
const
178 ArrayRCP<const gno_t> &colIds,
179 ArrayRCP<const scalar_t> &values)
const
// getRowWeightsView: rejects idx outside [0, nWeightsPerRow_) with
// std::runtime_error, then returns the strided weight view stored for idx.
192 if(idx<0 || idx >= nWeightsPerRow_)
194 std::ostringstream emsg;
195 emsg << __FILE__ <<
":" << __LINE__
196 <<
" Invalid row weight index " << idx << std::endl;
197 throw std::runtime_error(emsg.str());
202 rowWeights_[idx].getStridedList(length,
weights, stride);
// applyPartitioningSolution overloads (raw-pointer and RCP output).
207 template <
typename Adapter>
209 const PartitioningSolution<Adapter> &solution)
const;
211 template <
typename Adapter>
213 const PartitioningSolution<Adapter> &solution)
const;
// Wrapped matrix and its row / column maps (set in the constructor).
217 RCP<const User> matrix_;
218 RCP<const Tpetra::Map<lno_t, gno_t, node_t> > rowMap_;
219 RCP<const Tpetra::Map<lno_t, gno_t, node_t> > colMap_;
// CRS copy of the local rows built by the constructor: offset_ has nrows+1
// entries, columnIds_ holds global column IDs, values_ the matching entries.
220 ArrayRCP<offset_t> offset_;
221 ArrayRCP<gno_t> columnIds_;
222 ArrayRCP<scalar_t> values_;
// One strided weight view per weight index, plus per-index flags set by
// setRowWeightIsNumberOfNonZeros.
225 ArrayRCP<StridedData<lno_t, scalar_t> > rowWeights_;
226 ArrayRCP<bool> numNzWeight_;
// Initialized to true by the constructor; its consumer is not visible here.
228 bool mayHaveDiagonalEntries;
// Builds a new User matrix holding the rows listed in myNewRows; only
// Tpetra::CrsMatrix inputs are supported.
230 RCP<User> doMigration(
const User &from,
size_t numLocalRows,
231 const gno_t *myNewRows)
const;
// Constructor: caches the row/column maps of inmatrix and copies its locally
// owned rows into CRS form (offset_, columnIds_, values_), translating column
// indices to global IDs. When nWeightsPerRow > 0 it allocates that many
// row-weight slots, none initially flagged as number-of-nonzeros weights.
238template <
typename User,
typename UserCoord>
240 const RCP<const User> &inmatrix,
int nWeightsPerRow):
241 matrix_(inmatrix), rowMap_(), colMap_(),
242 offset_(), columnIds_(),
243 nWeightsPerRow_(nWeightsPerRow), rowWeights_(), numNzWeight_(),
244 mayHaveDiagonalEntries(true)
246 typedef StridedData<lno_t,scalar_t> input_t;
248 rowMap_ = matrix_->getRowMap();
249 colMap_ = matrix_->getColMap();
251 size_t nrows = matrix_->getLocalNumRows();
252 size_t nnz = matrix_->getLocalNumEntries();
253 size_t maxnumentries =
254 matrix_->getLocalMaxNumRowEntries();
// Size the CRS arrays; offset_ starts as nrows+1 zeros and is accumulated below.
256 offset_.resize(nrows+1, 0);
257 columnIds_.resize(nnz);
// NOTE(review): values_ is written in the loop below, but no resize of values_
// is visible in this listing -- confirm it is sized to nnz before the loop.
// Scratch host views sized for the longest local row, reused for every row.
259 typename User::nonconst_local_inds_host_view_type indices(
"indices", maxnumentries);
260 typename User::nonconst_values_host_view_type nzs(
"nzs", maxnumentries);
// From here on nnz is reused as the per-row entry count filled in by
// getLocalRowCopy; the total was only needed for sizing above.
263 for (
size_t i=0; i < nrows; i++){
265 matrix_->getLocalRowCopy(row, indices, nzs, nnz);
266 for (
size_t j=0; j < nnz; j++){
267 values_[next] = nzs[j];
// One colMap_ lookup per entry converts the local column index to a global ID.
270 columnIds_[next++] = colMap_->getGlobalElement(indices[j]);
272 offset_[i+1] = offset_[i] + nnz;
275 if (nWeightsPerRow_ > 0){
// The trailing `true` gives each ArrayRCP ownership of the new[]'d array.
276 rowWeights_ = arcp(
new input_t [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
277 numNzWeight_ = arcp(
new bool [nWeightsPerRow_], 0, nWeightsPerRow_,
true);
278 for (
int i=0; i < nWeightsPerRow_; i++)
279 numNzWeight_[i] =
false;
// setWeights: for a MATRIX_ROW primary entity type, forwards to
// setRowWeights(weightVal, stride, idx). For any other primary entity type a
// std::runtime_error is thrown (column and nonzero weights are unsupported).
// NOTE(review): the else/brace lines between the two paths are missing from
// this listing.
284template <
typename User,
typename UserCoord>
286 const scalar_t *weightVal,
int stride,
int idx)
288 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
289 setRowWeights(weightVal, stride, idx);
292 std::ostringstream emsg;
293 emsg << __FILE__ <<
"," << __LINE__
294 <<
" error: setWeights not yet supported for"
295 <<
" columns or nonzeros."
297 throw std::runtime_error(emsg.str());
// setRowWeights: records weightVal as row weight number idx.
// Throws std::runtime_error if idx is outside [0, nWeightsPerRow_).
// The stored view spans getLocalNumRows()*stride values and is created with
// ownership = false, so the caller must keep weightVal alive for as long as
// the adapter's weights are used.
302template <
typename User,
typename UserCoord>
304 const scalar_t *weightVal,
int stride,
int idx)
306 typedef StridedData<lno_t,scalar_t> input_t;
307 if(idx<0 || idx >= nWeightsPerRow_)
309 std::ostringstream emsg;
310 emsg << __FILE__ <<
":" << __LINE__
311 <<
" Invalid row weight index " << idx << std::endl;
312 throw std::runtime_error(emsg.str());
315 size_t nvtx = getLocalNumRows();
// Non-owning wrap of the caller's buffer (last argument false).
316 ArrayRCP<const scalar_t> weightV(weightVal, 0, nvtx*stride,
false);
317 rowWeights_[idx] = input_t(weightV, stride);
// Presumably setWeightIsDegree(int idx); its signature is missing from this
// listing. For MATRIX_ROW it forwards to setRowWeightIsNumberOfNonZeros(idx),
// so a row's degree is its number of nonzeros. Any other primary entity type
// throws std::runtime_error.
// NOTE(review): the error text names setWeightIsNumberOfNonZeros and "columns"
// rather than this function / every non-row type -- confirm intended wording.
321template <
typename User,
typename UserCoord>
325 if (this->getPrimaryEntityType() ==
MATRIX_ROW)
326 setRowWeightIsNumberOfNonZeros(idx);
329 std::ostringstream emsg;
330 emsg << __FILE__ <<
"," << __LINE__
331 <<
" error: setWeightIsNumberOfNonZeros not yet supported for"
332 <<
" columns" << std::endl;
333 throw std::runtime_error(emsg.str());
// setRowWeightIsNumberOfNonZeros: flags row weight idx as "number of nonzeros
// in the row" by setting numNzWeight_[idx]. Presumably this flag is what
// useNumNonzerosAsRowWeight(idx) reports -- confirm.
// Throws std::runtime_error if idx is outside [0, nWeightsPerRow_).
338template <
typename User,
typename UserCoord>
342 if(idx<0 || idx >= nWeightsPerRow_)
344 std::ostringstream emsg;
345 emsg << __FILE__ <<
":" << __LINE__
346 <<
" Invalid row weight index " << idx << std::endl;
347 throw std::runtime_error(emsg.str());
351 numNzWeight_[idx] =
true;
// applyPartitioningSolution (raw-pointer output): builds the list of global
// row IDs this process receives under `solution` (presumably via
// getImportList, whose name line is missing here) and then builds the
// migrated matrix with doMigration.
// NOTE(review): where numNewRows is assigned and how outPtr is handed to `out`
// are not visible in this listing.
355template <
typename User,
typename UserCoord>
356 template <
typename Adapter>
358 const User &in, User *&out,
359 const PartitioningSolution<Adapter> &solution)
const
363 ArrayRCP<gno_t> importList;
366 TpetraRowMatrixAdapter<User,UserCoord> >
367 (solution,
this, importList);
372 RCP<User> outPtr = doMigration(in, numNewRows, importList.getRawPtr());
// applyPartitioningSolution (RCP output): same as the raw-pointer overload,
// but the matrix returned by doMigration is assigned directly to `out`.
// NOTE(review): the import-list call's name line and numNewRows' assignment
// are missing from this listing.
378template <
typename User,
typename UserCoord>
379 template <
typename Adapter>
381 const User &in, RCP<User> &out,
382 const PartitioningSolution<Adapter> &solution)
const
386 ArrayRCP<gno_t> importList;
389 TpetraRowMatrixAdapter<User,UserCoord> >
390 (solution,
this, importList);
395 out = doMigration(in, numNewRows, importList.getRawPtr());
// doMigration: returns a new Tpetra::CrsMatrix (as RCP<User>) whose local rows
// are the numLocalRows global IDs in myNewRows, with entries imported from
// `from`. Throws std::logic_error when `from` is not a Tpetra::CrsMatrix.
400template <
typename User,
typename UserCoord>
404 const gno_t *myNewRows
407 typedef Tpetra::Map<lno_t, gno_t, node_t>
map_t;
408 typedef Tpetra::CrsMatrix<scalar_t, lno_t, gno_t, node_t> tcrsmatrix_t;
419 const tcrsmatrix_t *pCrsMatrix =
dynamic_cast<const tcrsmatrix_t *
>(&from);
// Presumably thrown only when the cast above fails; the guarding `if` is
// missing from this listing.
422 throw std::logic_error(
"TpetraRowMatrixAdapter cannot migrate data for "
423 "your RowMatrix; it can migrate data only for "
424 "Tpetra::CrsMatrix. "
425 "You can inherit from TpetraRowMatrixAdapter and "
426 "implement migration for your RowMatrix.");
// Target row map: same global size, index base and communicator as the
// source row map, with myNewRows as this process's rows.
430 const RCP<const map_t> &smap = from.getRowMap();
431 gno_t numGlobalRows = smap->getGlobalNumElements();
432 gno_t base = smap->getMinAllGlobalIndex();
435 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
436 const RCP<const Teuchos::Comm<int> > &comm = from.getComm();
437 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
// One importer (source rows -> target rows), reused for the per-row counts and
// for the matrix entries.
440 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
// NOTE(review): numLocalRows is size_t (and getLocalNumElements() presumably
// is too); storing them in int narrows for very large local sizes.
460 int oldNumElts = smap->getLocalNumElements();
461 int newNumElts = numLocalRows;
// Ship each old row's entry count to its new owner so the new matrix can be
// preallocated. Counts travel as scalar_t values in a Tpetra::Vector.
464 typedef Tpetra::Vector<scalar_t, lno_t, gno_t, node_t> vector_t;
465 vector_t numOld(smap);
466 vector_t numNew(tmap);
467 for (
int lid=0; lid < oldNumElts; lid++){
// NOTE(review): replaceLocalValue(lid, ...) would avoid the local->global lookup.
468 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
469 scalar_t(from.getNumEntriesInLocalRow(lid)));
471 numNew.doImport(numOld, importer, Tpetra::INSERT);
474 ArrayRCP<size_t> nnz(newNumElts);
// Only read below; a read-only getData(0) view would also work.
476 ArrayRCP<scalar_t> ptr = numNew.getDataNonConst(0);
477 for (
int lid=0; lid < newNumElts; lid++){
478 nnz[lid] =
static_cast<size_t>(ptr[lid]);
// Preallocate with the imported per-row counts, then import the entries.
482 RCP<tcrsmatrix_t> M =
483 rcp(
new tcrsmatrix_t(tmap, nnz()));
485 M->doImport(from, importer, Tpetra::INSERT);
// NOTE(review): no fillComplete() is visible between the import and the
// return in this listing -- confirm it is called.
489 return Teuchos::rcp_dynamic_cast<User>(M);
Zoltan2::BasicUserTypes< zscalar_t, zlno_t, zgno_t > user_t
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through call stack.
Defines the MatrixAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
InputTraits< User >::node_t node_t
InputTraits< User >::offset_t offset_t
InputTraits< User >::part_t part_t
InputTraits< User >::scalar_t scalar_t
InputTraits< User >::lno_t lno_t
InputTraits< User >::gno_t gno_t
MatrixAdapter defines the adapter interface for matrices.
Provides access for Zoltan2 to Tpetra::RowMatrix data.
bool CRSViewAvailable() const
Indicates whether the MatrixAdapter implements a view of the matrix in compressed sparse row (CRS) format.
size_t getLocalNumEntries() const
Returns the number of nonzeros on this process.
size_t getLocalNumColumns() const
Returns the number of columns on this process.
size_t getLocalNumRows() const
Returns the number of rows on this process.
~TpetraRowMatrixAdapter()
Destructor.
TpetraRowMatrixAdapter(const RCP< const User > &inmatrix, int nWeightsPerRow=0)
Constructor.
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds, ArrayRCP< const scalar_t > &values) const
void getRowWeightsView(const scalar_t *&weights, int &stride, int idx=0) const
Provide a pointer to the row weights, if any.
void setRowWeightIsNumberOfNonZeros(int idx)
Specify an index for which the row weight should be the global number of nonzeros in the row.
void setWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each entity of the primaryEntityType.
void getCRSView(ArrayRCP< const offset_t > &offsets, ArrayRCP< const gno_t > &colIds) const
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
void getRowIDsView(const gno_t *&rowIds) const
bool useNumNonzerosAsRowWeight(int idx) const
Indicate whether row weight with index idx should be the global number of nonzeros in the row.
int getNumWeightsPerRow() const
Returns the number of weights per row (0 or greater). Row weights may be used when partitioning matrix rows.
void setRowWeights(const scalar_t *weightVal, int stride, int idx=0)
Specify a weight for each row.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
map_t::local_ordinal_type lno_t
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
size_t getImportList(const PartitioningSolution< SolutionAdapter > &solution, const DataAdapter *const data, ArrayRCP< typename DataAdapter::gno_t > &imports)
From a PartitioningSolution, get a list of IDs to be imported. Assumes part numbers in PartitioningSo...