50#ifndef _ZOLTAN2_TPETRAROWGRAPHADAPTER_HPP_
51#define _ZOLTAN2_TPETRAROWGRAPHADAPTER_HPP_
56#include <Tpetra_RowGraph.hpp>
81template <
typename User,
typename UserCoord=User>
86#ifndef DOXYGEN_SHOULD_SKIP_THIS
94 typedef UserCoord userCoord_t;
111 int nVtxWeights=0,
int nEdgeWeights=0);
// Point the caller's id pointer at the raw storage of the row map's local element list.
198 ids = graph_->getRowMap()->getLocalElementList().getRawPtr();
// Point the caller's offsets pointer at the raw storage of the offs_ array.
205 offsets = offs_.getRawPtr();
// Reject any weight index outside [0, nWeightsPerVertex_): build a message with
// the file name, line number and offending index, then throw std::runtime_error.
214 if(idx<0 || idx >= nWeightsPerVertex_)
216 std::ostringstream emsg;
217 emsg << __FILE__ <<
":" << __LINE__
218 <<
" Invalid vertex weight index " << idx << std::endl;
219 throw std::runtime_error(emsg.str());
// Valid index: fill length, weights and stride from the strided data stored at idx.
223 vertexWeights_[idx].getStridedList(length,
weights, stride);
// Reject any weight index outside [0, nWeightsPerEdge_): build a message with
// the file name, line number and offending index, then throw std::runtime_error.
232 if(idx<0 || idx >= nWeightsPerEdge_)
234 std::ostringstream emsg;
235 emsg << __FILE__ <<
":" << __LINE__
236 <<
" Invalid edge weight index " << idx << std::endl;
237 throw std::runtime_error(emsg.str());
// Valid index: fill length, weights and stride from the strided data stored at idx.
241 edgeWeights_[idx].getStridedList(length,
weights, stride);
// Declaration fragment of a const member template parameterised on Adapter; it
// takes a PartitioningSolution<Adapter> by const reference.
// NOTE(review): the line carrying the function name and leading parameters is
// not visible in this listing.
245 template <
typename Adapter>
247 const PartitioningSolution<Adapter> &solution)
const;
// Second declaration fragment with the same template parameter and trailing
// solution argument.
// NOTE(review): presumably an overload of the declaration above -- confirm
// against the full header.
249 template <
typename Adapter>
251 const PartitioningSolution<Adapter> &solution)
const;
// The graph handed to the constructor.
255 RCP<const User> graph_;
// Per-row start offsets into adjids_; filled by the constructor.
257 ArrayRCP<const offset_t> offs_;
// Global ids of the neighbours of each row, concatenated row after row.
258 ArrayRCP<const gno_t> adjids_;
// Number of weights per vertex, and one strided weight array per weight index.
260 int nWeightsPerVertex_;
261 ArrayRCP<StridedData<lno_t, scalar_t> > vertexWeights_;
// Per-weight-index flag; initialised to false and set to true by
// setVertexWeightIsDegree().
262 ArrayRCP<bool> vertexDegreeWeight_;
// Number of weights per edge, and one strided weight array per weight index.
264 int nWeightsPerEdge_;
265 ArrayRCP<StridedData<lno_t, scalar_t> > edgeWeights_;
// NOTE(review): coords_ is default-constructed by the constructor and not
// written anywhere in the code visible here.
268 ArrayRCP<StridedData<lno_t, scalar_t> > coords_;
// Helper that builds a new graph holding the numLocalRows rows listed in
// myNewRows, importing the data from `from`.
270 RCP<User> doMigration(
const User &from,
size_t numLocalRows,
271 const gno_t *myNewRows)
const;
// Constructor: stores the graph, copies its adjacency structure into
// offset/neighbour arrays indexed by local row, and allocates the per-index
// weight containers.
//   ingraph   - graph to wrap
//   nVtxWgts  - number of weights per vertex
//   nEdgeWgts - number of weights per edge
// NOTE(review): several lines of this definition (qualified name, braces,
// allocation of offs, some conditions) are not visible in this listing.
278template <
typename User,
typename UserCoord>
280 const RCP<const User> &ingraph,
int nVtxWgts,
int nEdgeWgts):
281 graph_(ingraph), offs_(),
283 nWeightsPerVertex_(nVtxWgts), vertexWeights_(), vertexDegreeWeight_(),
284 nWeightsPerEdge_(nEdgeWgts), edgeWeights_(),
285 coordinateDim_(0), coords_()
287 typedef StridedData<lno_t,scalar_t> input_t;
// Local sizes taken from the graph: row count, total entry count, and the
// largest number of entries in any single row.
289 size_t nvtx = graph_->getLocalNumRows();
290 size_t nedges = graph_->getLocalNumEntries();
291 size_t maxnumentries =
292 graph_->getLocalMaxNumRowEntries();
// Error path: report file, line and the requested object count n on stderr,
// then throw std::bad_alloc.
// NOTE(review): the guarding condition is not visible here.
302 std::cerr <<
"Error: " << __FILE__ <<
", " << __LINE__<< std::endl;
303 std::cerr << n <<
" objects" << std::endl;
304 throw std::bad_alloc();
// Neighbour-id buffer; stays NULL unless the allocation below runs.
307 gno_t *adjids = NULL;
310 adjids =
new gno_t [nedges];
// Same error reporting as above, for the nedges-sized allocation.
314 std::cerr <<
"Error: " << __FILE__ <<
", " << __LINE__<< std::endl;
315 std::cerr << nedges <<
" objects" << std::endl;
316 throw std::bad_alloc();
// Scratch view for one row's local column indices, sized for the longest row.
320 typename User::nonconst_local_inds_host_view_type nbors(
"nbors", maxnumentries);
// For each local row: copy its local column indices into nbors, extend the
// offsets by the row's entry count, and translate each local column index to a
// global id through the column map.
323 for (
size_t v=0; v < nvtx; v++){
// nedges is passed as the third argument and then used as this row's entry
// count, so its earlier value (the total entry count) is overwritten here.
324 graph_->getLocalRowCopy(v, nbors, nedges);
325 offs[v+1] = offs[v] + nedges;
326 for (
offset_t e=offs[v], i=0; e < offs[v+1]; e++) {
327 adjids[e] = graph_->getColMap()->getGlobalElement(nbors[i++]);
// Wrap the raw buffers in ArrayRCPs (final argument true; presumably transfers
// ownership -- confirm against Teuchos::arcp).
331 offs_ = arcp(offs, 0, n,
true);
// NOTE(review): nedges was overwritten inside the row loop, so at this point it
// appears to hold the last row's entry count rather than the total. If so,
// adjids_ is created with the wrong size; offs[nvtx] looks like the intended
// value -- confirm against the full source.
332 adjids_ = arcp(adjids, 0, nedges,
true);
// Allocate one StridedData slot and one degree flag per vertex weight index;
// every flag starts out false.
// NOTE(review): the assignment target of the first arcp call is not visible in
// this listing.
334 if (nWeightsPerVertex_ > 0) {
336 arcp(
new input_t[nWeightsPerVertex_], 0, nWeightsPerVertex_,
true);
337 vertexDegreeWeight_ =
338 arcp(
new bool[nWeightsPerVertex_], 0, nWeightsPerVertex_,
true);
339 for (
int i=0; i < nWeightsPerVertex_; i++)
340 vertexDegreeWeight_[i] =
false;
// Allocate one StridedData slot per edge weight index.
343 if (nWeightsPerEdge_ > 0)
344 edgeWeights_ = arcp(
new input_t[nWeightsPerEdge_], 0, nWeightsPerEdge_,
true);
// Forwards a weight array to setVertexWeights() or setEdgeWeights() with the
// same (weightVal, stride, idx) arguments.
// NOTE(review): the function name line and the conditions choosing between the
// two calls are not visible in this listing.
348template <
typename User,
typename UserCoord>
350 const scalar_t *weightVal,
int stride,
int idx)
353 setVertexWeights(weightVal, stride, idx);
355 setEdgeWeights(weightVal, stride, idx);
// Registers a caller-provided array as the vertex weights for weight index idx.
//   weightVal - pointer to the weight values
//   stride    - stride between consecutive weights in weightVal
//   idx       - weight index; must satisfy 0 <= idx < nWeightsPerVertex_
// Throws std::runtime_error if idx is out of range.
359template <
typename User,
typename UserCoord>
361 const scalar_t *weightVal,
int stride,
int idx)
363 typedef StridedData<lno_t,scalar_t> input_t;
364 if(idx<0 || idx >= nWeightsPerVertex_)
366 std::ostringstream emsg;
367 emsg << __FILE__ <<
":" << __LINE__
368 <<
" Invalid vertex weight index " << idx << std::endl;
369 throw std::runtime_error(emsg.str());
// Wrap the caller's pointer as an ArrayRCP of nvtx*stride entries (final
// argument false; presumably non-owning -- confirm against Teuchos::ArrayRCP)
// and store it as the strided weight data for idx.
372 size_t nvtx = getLocalNumVertices();
373 ArrayRCP<const scalar_t> weightV(weightVal, 0, nvtx*stride,
false);
374 vertexWeights_[idx] = input_t(weightV, stride);
// Either forwards to setVertexWeightIsDegree(idx) or throws std::runtime_error
// with a message saying the operation is supported only for vertices.
// NOTE(review): the signature and the condition selecting between the two
// paths are not visible in this listing.
378template <
typename User,
typename UserCoord>
383 setVertexWeightIsDegree(idx);
385 std::ostringstream emsg;
386 emsg << __FILE__ <<
"," << __LINE__
387 <<
" error: setWeightIsNumberOfNonZeros is supported only for"
388 <<
" vertices" << std::endl;
389 throw std::runtime_error(emsg.str());
// Sets the degree flag for vertex weight index idx.
// Throws std::runtime_error if idx is outside [0, nWeightsPerVertex_).
// NOTE(review): the signature line is not visible in this listing.
394template <
typename User,
typename UserCoord>
398 if(idx<0 || idx >= nWeightsPerVertex_)
400 std::ostringstream emsg;
401 emsg << __FILE__ <<
":" << __LINE__
402 <<
" Invalid vertex weight index " << idx << std::endl;
403 throw std::runtime_error(emsg.str());
// Record the flag for this weight index.
406 vertexDegreeWeight_[idx] =
true;
// Registers a caller-provided array as the edge weights for weight index idx.
//   weightVal - pointer to the weight values
//   stride    - stride between consecutive weights in weightVal
//   idx       - weight index; must satisfy 0 <= idx < nWeightsPerEdge_
// Throws std::runtime_error if idx is out of range.
410template <
typename User,
typename UserCoord>
412 const scalar_t *weightVal,
int stride,
int idx)
414 typedef StridedData<lno_t,scalar_t> input_t;
416 if(idx<0 || idx >= nWeightsPerEdge_)
418 std::ostringstream emsg;
419 emsg << __FILE__ <<
":" << __LINE__
420 <<
" Invalid edge weight index " << idx << std::endl;
421 throw std::runtime_error(emsg.str());
// Wrap the caller's pointer as an ArrayRCP of nedges*stride entries (final
// argument false; presumably non-owning -- confirm against Teuchos::ArrayRCP)
// and store it as the strided weight data for idx.
424 size_t nedges = getLocalNumEdges();
425 ArrayRCP<const scalar_t> weightV(weightVal, 0, nedges*stride,
false);
426 edgeWeights_[idx] = input_t(weightV, stride);
// Raw-pointer output variant: derives an import list from the partitioning
// solution and migrates `in` via doMigration().
//   in       - graph to migrate
//   out      - reference to a raw User pointer
//   solution - partitioning solution the import list is derived from
// NOTE(review): the function name line, the call that fills importList and
// numNewVtx, and the assignment to `out` are not visible in this listing.
430template <
typename User,
typename UserCoord>
431 template<
typename Adapter>
433 const User &in, User *&out,
434 const PartitioningSolution<Adapter> &solution)
const
// Ids this process will hold after migration; passed as the output argument
// of the call below.
438 ArrayRCP<gno_t> importList;
441 TpetraRowGraphAdapter<User,UserCoord> >
442 (solution,
this, importList);
// Build the migrated graph from the import list.
447 RCP<User> outPtr = doMigration(in, numNewVtx, importList.getRawPtr());
// RCP output variant: derives an import list from the partitioning solution
// and stores the graph returned by doMigration() in `out`.
//   in       - graph to migrate
//   out      - receives the migrated graph
//   solution - partitioning solution the import list is derived from
// NOTE(review): the function name line and the call that fills importList and
// numNewVtx are not visible in this listing.
453template <
typename User,
typename UserCoord>
454 template<
typename Adapter>
456 const User &in, RCP<User> &out,
457 const PartitioningSolution<Adapter> &solution)
const
// Ids this process will hold after migration; passed as the output argument
// of the call below.
461 ArrayRCP<gno_t> importList;
464 TpetraRowGraphAdapter<User,UserCoord> >
465 (solution,
this, importList);
470 out = doMigration(in, numNewVtx, importList.getRawPtr());
// Builds a new Tpetra::CrsGraph whose row map holds the rows listed in
// myNewRows and imports the source graph's data into it.
// Returns the new graph cast to RCP<User>.
// NOTE(review): the signature's leading lines, braces, and some if/else lines
// are not visible in this listing.
475template <
typename User,
typename UserCoord>
479 const gno_t *myNewRows
482 typedef Tpetra::Map<lno_t, gno_t, node_t>
map_t;
483 typedef Tpetra::CrsGraph<lno_t, gno_t, node_t> tcrsgraph_t;
// Try to view the source as a CrsGraph; it is the import source further down.
494 const tcrsgraph_t *pCrsGraphSrc =
dynamic_cast<const tcrsgraph_t *
>(&from);
// Error path: throws std::logic_error saying migration is not supported for
// this RowGraph.
// NOTE(review): presumably taken when the cast above yields null; the
// condition is not visible here.
497 throw std::logic_error(
"TpetraRowGraphAdapter cannot migrate data for "
498 "your RowGraph; it can migrate data only for "
500 "You can inherit from TpetraRowGraphAdapter and "
501 "implement migration for your RowGraph.");
// Properties of the source row map.
505 const RCP<const map_t> &smap = from.getRowMap();
506 int oldNumElts = smap->getLocalNumElements();
507 gno_t numGlobalRows = smap->getGlobalNumElements();
508 gno_t base = smap->getMinAllGlobalIndex();
// Target row map: same global size, index base and communicator, with this
// process's rows given by myNewRows.
511 ArrayView<const gno_t> rowList(myNewRows, numLocalRows);
512 const RCP<const Teuchos::Comm<int> > &comm = from.getComm();
513 RCP<const map_t> tmap = rcp(
new map_t(numGlobalRows, rowList, base, comm));
// Importer from the source map to the target map.
516 Tpetra::Import<lno_t, gno_t, node_t> importer(smap, tmap);
// Record each source row's entry count in a vector on the source map, then
// import the counts onto the target map.
519 typedef Tpetra::Vector<gno_t, lno_t, gno_t, node_t> vector_t;
520 vector_t numOld(smap);
521 vector_t numNew(tmap);
522 for (
int lid=0; lid < oldNumElts; lid++){
523 numOld.replaceGlobalValue(smap->getGlobalElement(lid),
524 from.getNumEntriesInLocalRow(lid));
526 numNew.doImport(numOld, importer, Tpetra::INSERT);
// Per-row entry counts for the rows of the target map.
528 size_t numElts = tmap->getLocalNumElements();
529 ArrayRCP<const gno_t> nnz;
531 nnz = numNew.getData(0);
// The counts are needed as size_t. When gno_t and size_t differ in size, copy
// them element by element into a new array.
533 ArrayRCP<const size_t> nnz_size_t;
535 if (numElts &&
sizeof(
gno_t) !=
sizeof(
size_t)){
536 size_t *vals =
new size_t [numElts];
537 nnz_size_t = arcp(vals, 0, numElts,
true);
538 for (
size_t i=0; i < numElts; i++){
539 vals[i] =
static_cast<size_t>(nnz[i]);
// Otherwise reinterpret the existing array in place.
// NOTE(review): presumably the else branch of the size check; the else line is
// not visible here.
543 nnz_size_t = arcp_reinterpret_cast<const size_t>(nnz);
// Create the target graph from the target map and per-row counts, then import
// the source graph's data.
// NOTE(review): the declaration of G is not visible in this listing.
548 rcp(
new tcrsgraph_t(tmap, nnz_size_t()));
550 G->doImport(*pCrsGraphSrc, importer, Tpetra::INSERT);
552 return Teuchos::rcp_dynamic_cast<User>(G);
Zoltan2::BasicUserTypes< zscalar_t, zlno_t, zgno_t > user_t
#define Z2_FORWARD_EXCEPTIONS
Forward an exception back through call stack.
Defines the GraphAdapter interface.
Helper functions for Partitioning Problems.
This file defines the StridedData class.
InputTraits< User >::node_t node_t
InputTraits< User >::offset_t offset_t
InputTraits< User >::part_t part_t
InputTraits< User >::scalar_t scalar_t
InputTraits< User >::lno_t lno_t
InputTraits< User >::gno_t gno_t
GraphAdapter defines the interface for graph-based user data.
Provides access for Zoltan2 to Tpetra::RowGraph data.
size_t getLocalNumEdges() const
Returns the number of edges on this process.
void getEdgeWeightsView(const scalar_t *&weights, int &stride, int idx) const
Provide a pointer to the edge weights, if any.
size_t getLocalNumVertices() const
Returns the number of vertices on this process.
void getVertexIDsView(const gno_t *&ids) const
bool useDegreeAsVertexWeight(int idx) const
Indicate whether vertex weight with index idx should be the global degree of the vertex.
void getEdgesView(const offset_t *&offsets, const gno_t *&adjIds) const
void setEdgeWeights(const scalar_t *val, int stride, int idx)
Provide a pointer to edge weights.
void setWeightIsDegree(int idx)
Specify an index for which the weight should be the degree of the entity.
~TpetraRowGraphAdapter()
Destructor.
void applyPartitioningSolution(const User &in, User *&out, const PartitioningSolution< Adapter > &solution) const
int getNumWeightsPerVertex() const
Returns the number (0 or greater) of weights per vertex.
void setWeights(const scalar_t *val, int stride, int idx)
Provide a pointer to weights for the primary entity type.
void setVertexWeights(const scalar_t *val, int stride, int idx)
Provide a pointer to vertex weights.
TpetraRowGraphAdapter(const RCP< const User > &ingraph, int nVtxWeights=0, int nEdgeWeights=0)
Constructor for graph with no weights or coordinates.
void setVertexWeightIsDegree(int idx)
Specify an index for which the vertex weight should be the degree of the vertex.
void getVertexWeightsView(const scalar_t *&weights, int &stride, int idx) const
Provide a pointer to the vertex weights, if any.
int getNumWeightsPerEdge() const
Returns the number (0 or greater) of edge weights.
map_t::global_ordinal_type gno_t
Created by mbenlioglu on Aug 31, 2020.
size_t getImportList(const PartitioningSolution< SolutionAdapter > &solution, const DataAdapter *const data, ArrayRCP< typename DataAdapter::gno_t > &imports)
From a PartitioningSolution, get a list of IDs to be imported. Assumes part numbers in PartitioningSolution are to be used as rank numbers.