Viewing File: /home/ubuntu/combine_ai/combine/lib/python3.10/site-packages/torch/include/ATen/mps/MPSGuardImpl.h

//  Copyright © 2022 Apple Inc.

#pragma once
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSEvent.h>

#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#endif

#include <ATen/Tensor.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <sys/_types/_size_t.h>
#include <memory>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/intrusive_ptr.h>


namespace at::mps {

// Shorthand for a pointer to an MPSEvent; this is the handle type taken by
// MPSGuardImpl::createEvent.
using mpsEvent_t = MPSEvent*;

// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
//
// Device-guard implementation for the MPS backend.
//
// Only one MPS device (index 0) and one stream (Stream::DEFAULT) are
// modelled: every device or stream query below answers with MPS device 0 /
// the default stream and ignores its argument. The event-related members are
// only declared in this header; their definitions are not part of it.
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;

  // --- construction --------------------------------------------------------

  MPSGuardImpl() {}

  // The device type is accepted only to be checked: anything other than MPS
  // trips an internal assertion.
  explicit MPSGuardImpl(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == c10::DeviceType::MPS);
  }

  // --- device queries / switching ------------------------------------------

  // Always the MPS device type.
  c10::DeviceType type() const override {
    return static_type;
  }

  // No real switch happens; MPS device 0 is reported as the previous device.
  Device exchangeDevice(Device d) const override {
    return Device(static_type, 0);
  }

  // The current device is always MPS device 0.
  Device getDevice() const override {
    return Device(static_type, 0);
  }

  // Non-throwing counterpart of getDevice(); never yields an empty optional.
  c10::optional<Device> uncheckedGetDevice() const noexcept {
    return Device(static_type, 0);
  }

  // Only validates that `d` is an MPS device; no state is changed.
  void setDevice(Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_mps());
  }

  // Intentionally does nothing.
  void uncheckedSetDevice(Device d) const noexcept override {
    // TODO: Currently setting only device 0
  }

  // --- streams (NB: none of these set the current device) ------------------

  // Always the default stream on MPS device 0.
  Stream getStream(Device d) const noexcept override {
    return Stream(Stream::DEFAULT, Device(static_type, 0));
  }

  // Always the default stream on MPS device 0.
  Stream getDefaultStream(Device d) const override {
    return Stream(Stream::DEFAULT, Device(static_type, 0));
  }

  // Does not install `s`; the default MPS stream is reported as the previous
  // stream.
  Stream exchangeStream(Stream s) const noexcept override {
    return Stream(Stream::DEFAULT, Device(static_type, 0));
  }

  // One device when at::hasMPS() is true, zero otherwise.
  DeviceIndex deviceCount() const noexcept override {
    if (!at::hasMPS()) {
      return 0;
    }
    // TODO: extend it for multi-device case
    return 1;
  }

  // --- events (declared here, defined outside this header) -----------------

  // Not part of the base interface (no `override`); takes an mpsEvent_t
  // handle rather than the type-erased void* used by the overrides below.
  void createEvent(
    mpsEvent_t* event,
    const EventFlag flag) const;

  void destroyEvent(
    void* event,
    const DeviceIndex device_index) const noexcept override;

  void record(
    void** event,
    const Stream& stream,
    const DeviceIndex device_index,
    const EventFlag flag) const override;

  void block(
    void* event,
    const Stream& stream) const override;

  bool queryEvent(void* event) const override;

};

/// A variant of OptionalDeviceGuard that is specialized for MPS.
///
/// Every member forwards directly to the wrapped
/// c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl>; instances can be
/// neither copied nor moved.
struct OptionalMPSGuard {
  // --- construction --------------------------------------------------------

  /// Default-constructs the wrapped guard (expected to start out
  /// uninitialized).
  explicit OptionalMPSGuard() : guard_() {}

  /// Forwards `device_opt` to the wrapped guard; presumably, as for the index
  /// overload below, a device is only set when `device_opt` is not nullopt.
  explicit OptionalMPSGuard(c10::optional<Device> device_opt)
      : guard_(device_opt) {}

  /// Sets the current MPS device to the given index when
  /// `device_index_opt` is not nullopt.
  explicit OptionalMPSGuard(c10::optional<DeviceIndex> device_index_opt)
      : guard_(device_index_opt) {}

  // Neither copyable nor movable.
  OptionalMPSGuard(const OptionalMPSGuard&) = delete;
  OptionalMPSGuard(OptionalMPSGuard&& other) = delete;
  OptionalMPSGuard& operator=(const OptionalMPSGuard&) = delete;
  OptionalMPSGuard& operator=(OptionalMPSGuard&& other) = delete;

  // --- changing the device -------------------------------------------------

  /// Switches to `device`, initializing the guard first if needed.
  /// Errors if `device` is not an MPS device.
  void set_device(Device device) {
    guard_.set_device(device);
  }

  /// Switches to `device`, initializing the guard first if needed.
  /// Errors if `device` is not an MPS device.
  void reset_device(Device device) {
    guard_.reset_device(device);
  }

  /// Switches to the MPS device with index `device_index`, initializing the
  /// guard first if needed.
  void set_index(DeviceIndex device_index) {
    guard_.set_index(device_index);
  }

  // --- inspection ----------------------------------------------------------

  /// The device that was current just before the guard was initialized;
  /// nullopt while the guard is uninitialized.
  c10::optional<Device> original_device() const {
    return guard_.original_device();
  }

  /// The device most recently set through this guard (at construction or via
  /// set_device); nullopt while the guard is uninitialized.
  c10::optional<Device> current_device() const {
    return guard_.current_device();
  }

  // --- teardown ------------------------------------------------------------

  /// Restores the original MPS device and returns the guard to the
  /// uninitialized state.
  void reset() {
    guard_.reset();
  }

 private:
  c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
};


// Hands MPSGuardImpl to c10's C10_REGISTER_GUARD_IMPL macro for the MPS
// device type (presumably making it the guard implementation c10 uses for
// MPS devices — the macro's expansion is not visible here).
// NOTE(review): this sits in a header, so the registration presumably expands
// once in every translation unit that includes this file — confirm repeated
// registration is intended/harmless.
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl);

} // namespace at::mps
Back to Directory File Manager