33#ifndef GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
34#define GKO_PUBLIC_CORE_SOLVER_SOLVER_BASE_HPP_
42#include <ginkgo/core/base/lin_op.hpp>
43#include <ginkgo/core/base/math.hpp>
44#include <ginkgo/core/log/logger.hpp>
45#include <ginkgo/core/matrix/dense.hpp>
46#include <ginkgo/core/matrix/identity.hpp>
47#include <ginkgo/core/solver/workspace.hpp>
48#include <ginkgo/core/stop/combined.hpp>
49#include <ginkgo/core/stop/criterion.hpp>
52GKO_BEGIN_DISABLE_DEPRECATION_WARNINGS
96 friend class multigrid::detail::MultigridState;
111 virtual void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
117 apply_with_initial_guess(b.get(), x.get(),
guess);
132 virtual void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
143 apply_with_initial_guess(alpha.get(), b.get(), beta.get(), x.get(),
189template <
typename DerivedType>
192 friend class multigrid::detail::MultigridState;
203 void apply_with_initial_guess(
const LinOp* b,
LinOp* x,
207 auto exec = self()->get_executor();
208 GKO_ASSERT_CONFORMANT(self(), b);
209 GKO_ASSERT_EQUAL_ROWS(self(), x);
210 GKO_ASSERT_EQUAL_COLS(b, x);
221 void apply_with_initial_guess(
const LinOp* alpha,
const LinOp* b,
226 self(), alpha, b, beta, x);
227 auto exec = self()->get_executor();
228 GKO_ASSERT_CONFORMANT(self(), b);
229 GKO_ASSERT_EQUAL_ROWS(self(), x);
230 GKO_ASSERT_EQUAL_COLS(b, x);
231 GKO_ASSERT_EQUAL_DIMENSIONS(alpha,
dim<2>(1, 1));
232 GKO_ASSERT_EQUAL_DIMENSIONS(beta,
dim<2>(1, 1));
233 this->apply_with_initial_guess_impl(
239 self(), alpha, b, beta, x);
247 virtual void apply_with_initial_guess_impl(
254 virtual void apply_with_initial_guess_impl(
266template <
typename Solver>
269 static int num_vectors(
const Solver&) {
return 0; }
271 static int num_arrays(
const Solver&) {
return 0; }
273 static std::vector<std::string> op_names(
const Solver&) {
return {}; }
275 static std::vector<std::string> array_names(
const Solver&) {
return {}; }
277 static std::vector<int> scalars(
const Solver&) {
return {}; }
279 static std::vector<int> vectors(
const Solver&) {
return {}; }
298template <
typename DerivedType>
309 auto exec = self()->get_executor();
326 if (&
other !=
this) {
339 if (&
other !=
this) {
341 other.set_preconditioner(
nullptr);
367 *
this = std::move(
other);
391class SolverBaseLinOp {
393 SolverBaseLinOp(std::shared_ptr<const Executor> exec)
394 : workspace_{std::
move(exec)}
397 virtual ~SolverBaseLinOp() =
default;
404 std::shared_ptr<const LinOp> get_system_matrix()
const
406 return system_matrix_;
409 const LinOp* get_workspace_op(
int vector_id)
const
414 virtual int get_num_workspace_ops()
const {
return 0; }
416 virtual std::vector<std::string> get_workspace_op_names()
const
425 virtual std::vector<int> get_workspace_scalars()
const {
return {}; }
431 virtual std::vector<int> get_workspace_vectors()
const {
return {}; }
434 void set_system_matrix_base(std::shared_ptr<const LinOp> system_matrix)
436 system_matrix_ = std::move(system_matrix);
439 void set_workspace_size(
int num_operators,
int num_arrays)
const
444 template <
typename LinOpType>
450 return LinOpType::create(this->workspace_.get_executor(), size);
455 template <
typename LinOpType>
460 vector_id, [&] {
return LinOpType::create_with_config_of(vec); },
461 typeid(*vec), vec->get_size(), vec->get_stride());
464 template <
typename LinOpType>
472 return LinOpType::create_with_type_of(
473 vec, workspace_.get_executor(), size, size[1]);
475 typeid(*vec), size, size[1]);
478 template <
typename LinOpType>
487 return LinOpType::create_with_type_of(
494 template <
typename ValueType>
495 matrix::Dense<ValueType>* create_workspace_scalar(
int vector_id,
501 return matrix::Dense<ValueType>::create(
502 workspace_.get_executor(), dim<2>{1, size});
504 typeid(matrix::Dense<ValueType>),
gko::dim<2>{1, size}, size);
507 template <
typename ValueType>
514 template <
typename ValueType>
521 mutable detail::workspace workspace_;
523 std::shared_ptr<const LinOp> system_matrix_;
530template <
typename MatrixType>
533 GKO_DEPRECATED(
"This class will be replaced by the template-less detail::SolverBaseLinOp in a future release")
SolverBase
535 :
public detail::SolverBaseLinOp {
537 using detail::SolverBaseLinOp::SolverBaseLinOp;
548 return std::dynamic_pointer_cast<const MatrixType>(
549 SolverBaseLinOp::get_system_matrix());
553 void set_system_matrix_base(std::shared_ptr<const MatrixType> system_matrix)
555 SolverBaseLinOp::set_system_matrix_base(std::move(system_matrix));
569template <
typename DerivedType,
typename MatrixType = LinOp>
578 if (&
other !=
this) {
579 set_system_matrix(
other.get_system_matrix());
590 if (&
other !=
this) {
591 set_system_matrix(
other.get_system_matrix());
592 other.set_system_matrix(
nullptr);
599 EnableSolverBase(std::shared_ptr<const MatrixType> system_matrix)
600 : SolverBase<
MatrixType>{self()->get_executor()}
602 set_system_matrix(std::move(system_matrix));
621 *
this = std::move(
other);
624 int get_num_workspace_ops()
const override
627 return traits::num_vectors(*self());
630 std::vector<std::string> get_workspace_op_names()
const override
633 return traits::op_names(*self());
643 return traits::scalars(*self());
653 return traits::vectors(*self());
659 auto exec = self()->get_executor();
670 void setup_workspace()
const
673 this->set_workspace_size(traits::num_vectors(*self()),
674 traits::num_arrays(*self()));
703 return stop_factory_;
718 std::shared_ptr<const stop::CriterionFactory> stop_factory_;
731template <
typename DerivedType>
740 if (&
other !=
this) {
753 if (&
other !=
this) {
755 other.set_stop_criterion_factory(
nullptr);
763 std::shared_ptr<const stop::CriterionFactory>
stop_factory)
779 *
this = std::move(
other);
785 auto exec = self()->get_executor();
812template <
typename ValueType,
typename DerivedType>
821 std::shared_ptr<const LinOp> system_matrix,
822 std::shared_ptr<const stop::CriterionFactory>
stop_factory,
823 std::shared_ptr<const LinOp> preconditioner)
829 template <
typename FactoryParameters>
831 std::shared_ptr<const LinOp> system_matrix,
835 generate_preconditioner(system_matrix,
params)}
839 template <
typename FactoryParameters>
840 static std::shared_ptr<const LinOp> generate_preconditioner(
841 std::shared_ptr<const LinOp> system_matrix,
844 if (
params.generated_preconditioner) {
845 return params.generated_preconditioner;
846 }
else if (
params.preconditioner) {
847 return params.preconditioner->generate(system_matrix);
850 system_matrix->get_executor(), system_matrix->get_size());
856template <
typename Parameters,
typename Factory>
862 std::vector<std::shared_ptr<const stop::CriterionFactory>>
867template <
typename Parameters,
typename Factory>
874 std::shared_ptr<const LinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
890GKO_END_DISABLE_DEPRECATION_WARNINGS
Definition lin_op.hpp:146
A LinOp implementing this interface can be preconditioned.
Definition lin_op.hpp:711
virtual void set_preconditioner(std::shared_ptr< const LinOp > new_precond)
Sets the preconditioner operator used by the Preconditionable.
Definition lin_op.hpp:731
The enable_parameters_type mixin is used to create a base implementation of the factory parameters st...
Definition abstract_factory.hpp:239
This class is a utility which efficiently implements the identity matrix (a linear operator which map...
Definition identity.hpp:65
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:71
ApplyWithInitialGuess provides a way to give the input guess for apply function.
Definition solver_base.hpp:94
EnableApplyWithInitialGuess providing default operation for ApplyWithInitialGuess with correct valida...
Definition solver_base.hpp:190
A LinOp deriving from this CRTP class stores a stopping criterion factory and allows applying with a ...
Definition solver_base.hpp:732
EnableIterativeBase & operator=(EnableIterativeBase &&other)
Moves the provided stopping criterion, clones it onto this executor if executors don't match.
Definition solver_base.hpp:751
void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory) override
Sets the stopping criterion of the solver.
Definition solver_base.hpp:782
EnableIterativeBase(EnableIterativeBase &&other)
Moves the provided stopping criterion.
Definition solver_base.hpp:777
EnableIterativeBase(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion.
Definition solver_base.hpp:771
EnableIterativeBase & operator=(const EnableIterativeBase &other)
Creates a shallow copy of the provided stopping criterion, clones it onto this executor if executors ...
Definition solver_base.hpp:738
Mixin providing default operation for Preconditionable with correct value semantics.
Definition solver_base.hpp:299
EnablePreconditionable(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner.
Definition solver_base.hpp:356
EnablePreconditionable & operator=(EnablePreconditionable &&other)
Moves the provided preconditioner, clones it onto this executor if executors don't match.
Definition solver_base.hpp:337
EnablePreconditionable(EnablePreconditionable &&other)
Moves the provided preconditioner.
Definition solver_base.hpp:365
EnablePreconditionable & operator=(const EnablePreconditionable &other)
Creates a shallow copy of the provided preconditioner, clones it onto this executor if executors don'...
Definition solver_base.hpp:324
void set_preconditioner(std::shared_ptr< const LinOp > new_precond) override
Sets the preconditioner operator used by the Preconditionable.
Definition solver_base.hpp:307
A LinOp implementing this interface stores a system matrix and stopping criterion factory.
Definition solver_base.hpp:816
A LinOp deriving from this CRTP class stores a system matrix.
Definition solver_base.hpp:570
EnableSolverBase(EnableSolverBase &&other)
Moves the provided system matrix.
Definition solver_base.hpp:618
std::vector< int > get_workspace_vectors() const override
Returns the IDs of all vectors (workspace vectors with system dimension-dependent size,...
Definition solver_base.hpp:650
std::vector< int > get_workspace_scalars() const override
Returns the IDs of all scalars (workspace vectors with system dimension-independent size,...
Definition solver_base.hpp:640
EnableSolverBase(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix.
Definition solver_base.hpp:608
EnableSolverBase & operator=(EnableSolverBase &&other)
Moves the provided system matrix, clones it onto this executor if executors don't match.
Definition solver_base.hpp:588
EnableSolverBase & operator=(const EnableSolverBase &other)
Creates a shallow copy of the provided system matrix, clones it onto this executor if executors don't...
Definition solver_base.hpp:576
A LinOp implementing this interface stores a stopping criterion factory.
Definition solver_base.hpp:693
std::shared_ptr< const stop::CriterionFactory > get_stop_criterion_factory() const
Gets the stopping criterion factory of the solver.
Definition solver_base.hpp:700
virtual void set_stop_criterion_factory(std::shared_ptr< const stop::CriterionFactory > new_stop_factory)
Sets the stopping criterion of the solver.
Definition solver_base.hpp:711
Definition solver_base.hpp:535
std::shared_ptr< const MatrixType > get_system_matrix() const
Returns the system matrix, with its concrete type, used by the solver.
Definition solver_base.hpp:546
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:473
std::shared_ptr< const CriterionFactory > combine(FactoryContainer &&factories)
Combines multiple criterion factories into a single combined criterion factory.
Definition combined.hpp:138
initial_guess_mode
Give a initial guess mode about the input of the apply method.
Definition solver_base.hpp:62
@ provided
the input is provided
@ rhs
the input is right hand side
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:203
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:207
A type representing the dimensions of a multidimensional object.
Definition dim.hpp:55
Definition solver_base.hpp:858
std::vector< std::shared_ptr< const stop::CriterionFactory > > criteria
Stopping criteria to be used by the solver.
Definition solver_base.hpp:863
Definition solver_base.hpp:869
std::shared_ptr< const LinOp > generated_preconditioner
Already generated preconditioner.
Definition solver_base.hpp:882
std::shared_ptr< const LinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition solver_base.hpp:875
Traits class providing information on the type and location of workspace vectors inside a solver.
Definition solver_base.hpp:267