File size: 3,486 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
//  Copyright © 2022 Apple Inc.

// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h

#pragma once

#include <cstdint>
#include <utility>

#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include "MPSDevice.h"

// Handle types used by the declarations in this header.
//
// When this header is compiled as Objective-C(++), the handles are the real
// Metal protocol-qualified `id` types. Otherwise they are declared as opaque
// `void*` so that the header can still be parsed by a plain C++ compiler.
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// Stand-in for Objective-C's `nil`. The macro deliberately has NO trailing
// semicolon: with one, `nil` could only appear at the very end of a
// statement, and uses such as `x == nil` or `f(nil)` would not compile.
#define nil NULL
#endif

namespace at {
namespace mps {

//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

/// Pairs a generic `Stream` with the Metal command-queue, command-buffer and
/// dispatch-queue handles stored in the private members below.
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Builds an MPSStream from a generic Stream. The conversion is checked:
  /// an error is raised when `stream` is not actually an MPS stream.
  explicit MPSStream(Stream stream);
  ~MPSStream();

  // ---- Accessors (all inline, none modify the stream) ----

  /// The Metal command queue handle held by this stream.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; }

  /// Returns the same command queue handle as commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; }

  /// The dispatch queue handle held by this stream.
  dispatch_queue_t queue() const { return _serialQueue; }

  /// The device reported by the command queue.
  // NOTE(review): the body uses Objective-C message syntax, so this header
  // presumably has to be compiled as Objective-C++ despite the `void*`
  // fallback typedefs above -- confirm.
  MTLDevice_t device() const { return [_commandQueue device]; }

  /// The MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  /// Explicit conversion back to the wrapped Stream.
  Stream unwrap() const { return _stream; }

  // ---- Command submission (declared here, defined out of line) ----

  MTLCommandBuffer_t commandBuffer();
  void commit(bool flush);
  void commitAndWait();
  void synchronize();
  void flush();

 private:
  void _flush(bool commitAndWait) const;

  // Data members are kept in their original declaration order so that object
  // layout and member-initialization order are unchanged.
  Stream _stream;
  MTLCommandQueue_t _commandQueue = nil;
  MTLCommandBuffer_t _commandBuffer = nil;
  dispatch_queue_t _serialQueue = nullptr;
};

/**
 * Get the current MPS stream.
 *
 * Only declared in this header; the definition lives elsewhere.
 * NOTE(review): ownership of the returned pointer is not visible from this
 * header -- presumably callers must not delete it; confirm against the
 * implementation.
 */
TORCH_API MPSStream* getCurrentMPSStream();

/**
 * Get the default MPS stream.
 *
 * Only declared in this header; the definition lives elsewhere.
 * NOTE(review): whether this can differ from the stream returned by
 * getCurrentMPSStream() cannot be determined from this header -- confirm
 * against the implementation.
 */
TORCH_API MPSStream* getDefaultMPSStream();

//-----------------------------------------------------------------
//  MPSStreamImpl
//-----------------------------------------------------------------

class TORCH_API MPSStreamImpl {
 public:
  /**
   * Gets single instance of the MPSStream.
   */
  static MPSStream* getInstance();

 private:
  static MPSStream* _stream;
  MPSStreamImpl();
};

//-----------------------------------------------------------------
//  MPSEvent
//-----------------------------------------------------------------

struct TORCH_API MPSEvent {
  MPSEvent();
  // MPSEvent(id<MTLDevice> device);

  ~MPSEvent();
  MTLSharedEvent_t event() const { return _event; }

  void recordEvent(MPSStream* stream);
  void waitForEvent(MPSStream* queue);  // waits on the cpu
  bool queryEvent();
  uint64_t getCurrentValue() { return _currentValue; }
  void setCurrentValue(uint64_t currValue) { _currentValue = currValue; }

 private:
  bool _isRecorded = false;
  uint64_t _currentValue = 0;
  MTLSharedEvent_t _event;
};

typedef MPSEvent* mpsEvent_t;

}  // namespace mps
}  // namespace at