Viewing File: /home/ubuntu/combine_ai/combine/lib/python3.10/site-packages/torch/include/ATen/core/IListRef_inl.h
#pragma once
#include <ATen/core/List.h>
#include <ATen/core/Tensor.h>
// Forward declarations of the two element types that the `IListRefTagImpl`
// specializations further down in this file are written for.
namespace at {
class Tensor;
class OptionalTensorRef;
} // namespace at
namespace c10 {
namespace detail {
/*
* Specializations of `IListRefTagImplBase` that implement the default
* implementation for `IListRefTag::Unboxed`.
*/
template <typename T, typename ListElemT>
class IListRefTagImplBase<IListRefTag::Unboxed, T, ListElemT> {
public:
using elem_type = ListElemT;
using list_type = ArrayRef<elem_type>;
/*
* These `unwrap` static methods unwraps the inner containers out
* of `IListRef<T>` (and `IListRefIterator<T>`). They are required when
* the macro `TORCH_ILISTREF_UNWRAP` is called.
*/
static const list_type& unwrap(const IListRef<T>& ilist) {
return ilist.payload_.unboxed;
}
static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
return it.payload_.unboxed_iterator;
}
static const typename list_type::const_iterator& unwrap(
const IListRefIterator<T>& it) {
return it.payload_.unboxed_iterator;
}
/*
* We have these function (besides the `unwrap`s above) because the
* implementation for both `IListRef::operator[]` and `IListRefIterator::operator*`
* weren't syntatically equal for the existing tags at the time
* (`Unboxed` and `Boxed`).
*/
static IListRefConstRef<T> front(const list_type& lst) {
return lst.front();
}
static IListRefConstRef<T> iterator_get(
const typename list_type::const_iterator& it) {
return *it;
}
};
/*
* Specializations of `IListRefTagImplBase` that implement the default
* implementation for `IListRefTag::Boxed`.
*/
template <typename T, typename ListElemT>
class IListRefTagImplBase<IListRefTag::Boxed, T, ListElemT> {
public:
using elem_type = ListElemT;
using list_type = List<elem_type>;
static const list_type& unwrap(const IListRef<T>& ilist) {
return *ilist.payload_.boxed;
}
static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
return it.payload_.boxed_iterator;
}
static const typename list_type::const_iterator& unwrap(
const IListRefIterator<T>& it) {
return it.payload_.boxed_iterator;
}
static IListRefConstRef<T> front(const list_type& lst) {
return lst[0];
}
static IListRefConstRef<T> iterator_get(
const typename list_type::const_iterator& it) {
return (*it).get().toTensor();
}
};
/*
* Specializations of `IListRefTagImplBase` that implement the default
* implementation for `IListRefTag::Materialized`.
*/
template <typename T>
class IListRefTagImplBase<IListRefTag::Materialized, T, MaterializedIListRefElem<T>> {
public:
using elem_type = MaterializedIListRefElem<T>;
using list_type = MaterializedIListRef<T>;
static const list_type& unwrap(const IListRef<T>& ilist) {
return *ilist.payload_.materialized;
}
static typename list_type::const_iterator& unwrap(IListRefIterator<T>& it) {
return it.payload_.materialized_iterator;
}
static const typename list_type::const_iterator& unwrap(
const IListRefIterator<T>& it) {
return it.payload_.materialized_iterator;
}
static IListRefConstRef<T> front(const list_type& lst) {
return lst[0];
}
static IListRefConstRef<T> iterator_get(
const typename list_type::const_iterator& it) {
return *it;
}
};
/*
* [Note: ITensorListRef]
* Specializations necessary for the `IListRef<at::Tensor>` type.
*
* Since the default implementations were written with `Tensor` in mind,
* we only have to inherit from the base implementations.
*/
// `Unboxed` tag for `at::Tensor`: the class body is empty, so every member
// comes from the `IListRefTagImplBase` specialization it derives from. Only
// two template arguments are passed to the base, so the list element type is
// whatever the base template defaults to (declared outside this file).
template <>
class IListRefTagImpl<IListRefTag::Unboxed, at::Tensor>
: public IListRefTagImplBase<IListRefTag::Unboxed, at::Tensor> {};
// `Boxed` tag for `at::Tensor`: the class body is empty, so every member
// comes from the `IListRefTagImplBase` specialization it derives from. Only
// two template arguments are passed to the base, so the list element type is
// whatever the base template defaults to (declared outside this file).
template <>
class IListRefTagImpl<IListRefTag::Boxed, at::Tensor>
: public IListRefTagImplBase<IListRefTag::Boxed, at::Tensor> {};
// `Materialized` tag for `at::Tensor`: the class body is empty, so every
// member comes from the `IListRefTagImplBase` specialization it derives
// from. The list element type is passed explicitly as
// `MaterializedIListRefElem<at::Tensor>`.
template <>
class IListRefTagImpl<IListRefTag::Materialized, at::Tensor>
: public IListRefTagImplBase<
IListRefTag::Materialized,
at::Tensor,
MaterializedIListRefElem<at::Tensor>> {};
/*
* [Note: IOptTensorListRef]
* Specializations necessary for the `IListRef<at::OptionalTensorRef>` type.
*
* We can't get an `at::OptionalTensorRef` directly from an instance of
* `List<optional<Tensor>>` (the type that corresponds to the boxed world).
*
* So, the default implementation won't help us for the `Boxed` tag. Thus, we
* have to implement `iterator_get` ourselves in that specialization.
*/
// `Unboxed` tag for `at::OptionalTensorRef`: the class body is empty, so
// every member comes from the `IListRefTagImplBase` specialization it derives
// from. Only two template arguments are passed to the base, so the list
// element type is whatever the base template defaults to (declared outside
// this file).
template <>
class IListRefTagImpl<IListRefTag::Unboxed, at::OptionalTensorRef>
: public IListRefTagImplBase<IListRefTag::Unboxed, at::OptionalTensorRef> {};
template <>
class IListRefTagImpl<IListRefTag::Boxed, at::OptionalTensorRef>
    : public IListRefTagImplBase<
          IListRefTag::Boxed,
          at::OptionalTensorRef,
          optional<at::Tensor>> {
 public:
  /*
   * Replaces the inherited `iterator_get` for the `Boxed` tag so that the
   * result is an `at::OptionalTensorRef`.
   *
   * The result is empty when the boxed value is None, and also when it
   * holds a tensor that is not defined. Otherwise the result is built from
   * the held tensor.
   */
  static IListRefConstRef<at::OptionalTensorRef> iterator_get(
      const typename list_type::const_iterator& it) {
    const auto& boxed = (*it).get();

    // Nothing stored at this position.
    if (boxed.isNone()) {
      return {};
    }

    const auto& held = boxed.toTensor();

    // A tensor is stored, but it is undefined: report it as empty too.
    if (!held.defined()) {
      return at::OptionalTensorRef{};
    }

    return held;
  }
};
// `Materialized` tag for `at::OptionalTensorRef`: the class body is empty, so
// every member comes from the `IListRefTagImplBase` specialization it derives
// from. The list element type is passed explicitly as
// `MaterializedIListRefElem<at::OptionalTensorRef>`.
template <>
class IListRefTagImpl<IListRefTag::Materialized, at::OptionalTensorRef>
: public IListRefTagImplBase<
IListRefTag::Materialized,
at::OptionalTensorRef,
MaterializedIListRefElem<at::OptionalTensorRef>> {};
} // namespace detail
} // namespace c10
// Short names in `at` for the `c10::IListRef` family instantiated with the
// two element types handled in this file.
namespace at {
// [Note: ITensorListRef]
// Aliases for the `at::Tensor` instantiations: the list reference, its
// iterator, and the materialized list type from `c10::detail`.
using ITensorListRef = c10::IListRef<at::Tensor>;
using ITensorListRefIterator = c10::IListRefIterator<at::Tensor>;
using MaterializedITensorListRef = c10::detail::MaterializedIListRef<at::Tensor>;
// [Note: IOptTensorListRef]
// Aliases for the `at::OptionalTensorRef` instantiations: the list reference,
// its iterator, and the materialized list type from `c10::detail`.
using IOptTensorListRef = c10::IListRef<at::OptionalTensorRef>;
using IOptTensorListRefIterator = c10::IListRefIterator<at::OptionalTensorRef>;
using MaterializedIOptTensorListRef = c10::detail::MaterializedIListRef<at::OptionalTensorRef>;
} // namespace at
Back to Directory
File Manager