Ginkgo Generated from branch based on master. Ginkgo version 1.7.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_dense.hpp
/*******************************<GINKGO LICENSE>******************************
Copyright (c) 2017-2023, the Ginkgo authors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
******************************<GINKGO LICENSE>*******************************/
32
33#ifndef GKO_PUBLIC_CORE_MATRIX_BATCH_DENSE_HPP_
34#define GKO_PUBLIC_CORE_MATRIX_BATCH_DENSE_HPP_
35
36
37#include <initializer_list>
38#include <vector>
39
40
41#include <ginkgo/core/base/array.hpp>
42#include <ginkgo/core/base/batch_lin_op.hpp>
43#include <ginkgo/core/base/batch_multi_vector.hpp>
44#include <ginkgo/core/base/executor.hpp>
45#include <ginkgo/core/base/mtx_io.hpp>
46#include <ginkgo/core/base/range_accessors.hpp>
47#include <ginkgo/core/base/types.hpp>
48#include <ginkgo/core/base/utils.hpp>
49#include <ginkgo/core/matrix/dense.hpp>
50
51
52namespace gko {
53namespace batch {
54namespace matrix {
55
56
76template <typename ValueType = default_precision>
77class Dense final : public EnableBatchLinOp<Dense<ValueType>>,
78 public EnableCreateMethod<Dense<ValueType>>,
79 public ConvertibleTo<Dense<next_precision<ValueType>>> {
80 friend class EnableCreateMethod<Dense>;
82 friend class Dense<to_complex<ValueType>>;
83 friend class Dense<next_precision<ValueType>>;
84
85public:
86 using EnableBatchLinOp<Dense>::convert_to;
87 using EnableBatchLinOp<Dense>::move_to;
88
89 using value_type = ValueType;
90 using index_type = int32;
93 using absolute_type = remove_complex<Dense>;
94 using complex_type = to_complex<Dense>;
95
96 void convert_to(Dense<next_precision<ValueType>>* result) const override;
97
98 void move_to(Dense<next_precision<ValueType>>* result) override;
99
110 std::unique_ptr<unbatch_type> create_view_for_item(size_type item_id);
111
115 std::unique_ptr<const unbatch_type> create_const_view_for_item(
116 size_type item_id) const;
117
126 {
127 GKO_ASSERT(batch_id < this->get_num_batch_items());
128 return batch_id * this->get_common_size()[0] *
129 this->get_common_size()[1];
130 }
131
137 value_type* get_values() noexcept { return values_.get_data(); }
138
146 const value_type* get_const_values() const noexcept
147 {
148 return values_.get_const_data();
149 }
150
162 value_type& at(size_type batch_id, size_type row, size_type col)
163 {
164 GKO_ASSERT(batch_id < this->get_num_batch_items());
165 return values_.get_data()[linearize_index(batch_id, row, col)];
166 }
167
171 value_type at(size_type batch_id, size_type row, size_type col) const
172 {
173 GKO_ASSERT(batch_id < this->get_num_batch_items());
174 return values_.get_const_data()[linearize_index(batch_id, row, col)];
175 }
176
191 ValueType& at(size_type batch_id, size_type idx) noexcept
192 {
193 GKO_ASSERT(batch_id < this->get_num_batch_items());
194 return values_.get_data()[linearize_index(batch_id, idx)];
195 }
196
200 ValueType at(size_type batch_id, size_type idx) const noexcept
201 {
202 GKO_ASSERT(batch_id < this->get_num_batch_items());
203 return values_.get_const_data()[linearize_index(batch_id, idx)];
204 }
205
215 {
216 GKO_ASSERT(batch_id < this->get_num_batch_items());
217 return values_.get_data() + this->get_cumulative_offset(batch_id);
218 }
219
227 const value_type* get_const_values_for_item(
228 size_type batch_id) const noexcept
229 {
230 GKO_ASSERT(batch_id < this->get_num_batch_items());
231 return values_.get_const_data() + this->get_cumulative_offset(batch_id);
232 }
233
242 {
243 return values_.get_num_elems();
244 }
245
258 static std::unique_ptr<const Dense<value_type>> create_const(
259 std::shared_ptr<const Executor> exec, const batch_dim<2>& sizes,
260 gko::detail::const_array_view<ValueType>&& values);
261
271
286
292
302
303private:
304 inline size_type compute_num_elems(const batch_dim<2>& size)
305 {
306 return size.get_num_batch_items() * size.get_common_size()[0] *
307 size.get_common_size()[1];
308 }
309
316 Dense(std::shared_ptr<const Executor> exec,
317 const batch_dim<2>& size = batch_dim<2>{});
318
333 template <typename ValuesArray>
334 Dense(std::shared_ptr<const Executor> exec, const batch_dim<2>& size,
335 ValuesArray&& values)
336 : EnableBatchLinOp<Dense>(exec, size),
337 values_{exec, std::forward<ValuesArray>(values)}
338 {
339 // Ensure that the values array has the correct size
340 auto num_elems = compute_num_elems(size);
341 GKO_ENSURE_IN_BOUNDS(num_elems, values_.get_num_elems() + 1);
342 }
343
344 void apply_impl(const MultiVector<value_type>* b,
345 MultiVector<value_type>* x) const;
346
347 void apply_impl(const MultiVector<value_type>* alpha,
349 const MultiVector<value_type>* beta,
350 MultiVector<value_type>* x) const;
351
352 size_type linearize_index(size_type batch, size_type row,
353 size_type col) const noexcept
354 {
355 return this->get_cumulative_offset(batch) +
356 row * this->get_size().get_common_size()[1] + col;
357 }
358
359 size_type linearize_index(size_type batch, size_type idx) const noexcept
360 {
361 return linearize_index(batch, idx / this->get_common_size()[1],
362 idx % this->get_common_size()[1]);
363 }
364
365 array<value_type> values_;
366};
367
368
369} // namespace matrix
370} // namespace batch
371} // namespace gko
372
373
374#endif // GKO_PUBLIC_CORE_MATRIX_BATCH_DENSE_HPP_
ConvertibleTo interface is used to mark that the implementer can be converted to the object of Result...
Definition polymorphic_object.hpp:499
This mixin implements a static create() method on ConcreteType that dynamically allocates the memory,...
Definition polymorphic_object.hpp:776
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:691
value_type * get_data() noexcept
Returns a pointer to the block of memory used to store the elements of the array.
Definition array.hpp:646
const value_type * get_const_data() const noexcept
Returns a constant pointer to the block of memory used to store the elements of the array.
Definition array.hpp:655
size_type get_num_elems() const noexcept
Returns the number of elements in the array.
Definition array.hpp:637
Definition batch_lin_op.hpp:88
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition batch_lin_op.hpp:281
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:85
Dense is a batch matrix format which explicitly stores all values of the matrix in each of the batche...
Definition batch_dense.hpp:79
std::unique_ptr< unbatch_type > create_view_for_item(size_type item_id)
Creates a mutable view (of gko::matrix::Dense type) of one item of the batch::matrix::Dense<value_typ...
const Dense * apply(ptr_param< const MultiVector< value_type > > b, ptr_param< MultiVector< value_type > > x) const
value_type * get_values() noexcept
Returns a pointer to the array of values of the matrix.
Definition batch_dense.hpp:137
Dense * apply(ptr_param< const MultiVector< value_type > > b, ptr_param< MultiVector< value_type > > x)
Apply the matrix to a multi-vector.
value_type at(size_type batch_id, size_type row, size_type col) const
Returns a single element for a particular batch item.
Definition batch_dense.hpp:171
size_type get_cumulative_offset(size_type batch_id) const
Get the cumulative storage size offset.
Definition batch_dense.hpp:125
ValueType & at(size_type batch_id, size_type idx) noexcept
Returns a single element for a particular batch item.
Definition batch_dense.hpp:191
std::unique_ptr< const unbatch_type > create_const_view_for_item(size_type item_id) const
Creates a constant (immutable) view (of gko::matrix::Dense type) of one item of the batch::matrix::Dense<value_typ...
value_type * get_values_for_item(size_type batch_id) noexcept
Returns a pointer to the array of values of the matrix for a specific batch item.
Definition batch_dense.hpp:214
value_type & at(size_type batch_id, size_type row, size_type col)
Returns a single element for a particular batch item.
Definition batch_dense.hpp:162
size_type get_num_stored_elements() const noexcept
Returns the number of elements explicitly stored in the batch matrix, cumulative across all the batch...
Definition batch_dense.hpp:241
const Dense * apply(ptr_param< const MultiVector< value_type > > alpha, ptr_param< const MultiVector< value_type > > b, ptr_param< const MultiVector< value_type > > beta, ptr_param< MultiVector< value_type > > x) const
Dense * apply(ptr_param< const MultiVector< value_type > > alpha, ptr_param< const MultiVector< value_type > > b, ptr_param< const MultiVector< value_type > > beta, ptr_param< MultiVector< value_type > > x)
Apply the matrix to a multi-vector with a linear combination of the given input vector.
static std::unique_ptr< const Dense< value_type > > create_const(std::shared_ptr< const Executor > exec, const batch_dim< 2 > &sizes, gko::detail::const_array_view< ValueType > &&values)
Creates a constant (immutable) batch dense matrix from a constant array.
ValueType at(size_type batch_id, size_type idx) const noexcept
Returns a single element for a particular batch item.
Definition batch_dense.hpp:200
const value_type * get_const_values() const noexcept
Returns a constant pointer to the array of values of the matrix.
Definition batch_dense.hpp:146
const value_type * get_const_values_for_item(size_type batch_id) const noexcept
Returns a pointer to the array of values of the matrix for a specific batch item.
Definition batch_dense.hpp:227
Dense is a matrix format which explicitly stores all values of the matrix.
Definition dense.hpp:136
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:71
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition math.hpp:354
std::int32_t int32
32-bit signed integral type.
Definition types.hpp:137
typename detail::next_precision_impl< T >::type next_precision
Obtains the next type in the singly-linked precision list.
Definition math.hpp:490
typename detail::to_complex_s< T >::type to_complex
Obtain the type which adds the complex of complex/scalar type or the template parameter of class by a...
Definition math.hpp:373
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
A type representing the dimensions of a multidimensional batch object.
Definition batch_dim.hpp:56
dim< dimensionality, dimension_type > get_common_size() const
Get the common size of the batch items.
Definition batch_dim.hpp:72
size_type get_num_batch_items() const
Get the number of batch items stored.
Definition batch_dim.hpp:65