Spaces:
Running
Running
File size: 3,486 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
// Copyright © 2022 Apple Inc.
// This file is modified from:
// https://github.com/pytorch/pytorch/blob/a85d1f0bcdd02cf18d3b0517337458cb51a18cdb/aten/src/ATen/mps/MPSStream.h
#pragma once
#include <cstdint>
#include <utility>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#include "MPSDevice.h"
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// Objective-C++ translation units see the real Metal protocol types.
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
#else
// Plain C++ translation units get opaque pointer stand-ins so code can name
// these types without pulling in the Apple frameworks.
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
// No trailing ';' here: `nil` must expand to a bare expression so it can be
// used in comparisons and argument lists, not only as an initializer.
// Guarded so an existing definition is not clobbered.
#ifndef nil
#define nil NULL
#endif
#endif
namespace at {
namespace mps {
//-----------------------------------------------------------------
//  MPSStream
//-----------------------------------------------------------------

/// Associates a c10::Stream with the Metal command queue, command buffer
/// and serial dispatch queue handles used for MPS work.
class TORCH_API MPSStream {
 public:
  enum Unchecked { UNCHECKED };

  /// Construct a MPSStream from a Stream. This construction is checked,
  /// and will raise an error if the Stream is not, in fact, a MPS stream.
  explicit MPSStream(Stream stream);
  ~MPSStream();

  /// The Metal command queue held by this stream.
  MTLCommandQueue_t commandQueue() const { return _commandQueue; }
  /// The serial dispatch queue held by this stream.
  dispatch_queue_t queue() const { return _serialQueue; }

  MTLCommandBuffer_t commandBuffer();
  void commit(bool flush);
  void commitAndWait();
  void synchronize();
  void flush();

  /// Get the MPS device index that this stream is associated with.
  c10::DeviceIndex device_index() const { return _stream.device_index(); }

  /// Returns the same handle as commandQueue().
  MTLCommandQueue_t stream() const { return _commandQueue; }

#ifdef __OBJC__
  /// The Metal device of the held command queue. Only declared for
  /// Objective-C++ translation units: the body uses message-send syntax,
  /// which a plain C++ compiler cannot parse.
  MTLDevice_t device() const { return [_commandQueue device]; }
#endif

  /// Explicit conversion to Stream.
  Stream unwrap() const { return _stream; }

 private:
  Stream _stream;
  // nullptr (rather than the `nil` macro) is valid in both the Objective-C++
  // and plain C++ configurations, independent of how `nil` is defined.
  MTLCommandQueue_t _commandQueue = nullptr;
  MTLCommandBuffer_t _commandBuffer = nullptr;
  void _flush(bool commitAndWait) const;
  dispatch_queue_t _serialQueue = nullptr;
};
/// Returns the MPS stream that is currently in use.
TORCH_API MPSStream* getCurrentMPSStream();

/// Returns the default MPS stream.
TORCH_API MPSStream* getDefaultMPSStream();
//-----------------------------------------------------------------
// MPSStreamImpl
//-----------------------------------------------------------------
class TORCH_API MPSStreamImpl {
public:
/**
* Gets single instance of the MPSStream.
*/
static MPSStream* getInstance();
private:
static MPSStream* _stream;
MPSStreamImpl();
};
//-----------------------------------------------------------------
// MPSEvent
//-----------------------------------------------------------------
struct TORCH_API MPSEvent {
MPSEvent();
// MPSEvent(id<MTLDevice> device);
~MPSEvent();
MTLSharedEvent_t event() const { return _event; }
void recordEvent(MPSStream* stream);
void waitForEvent(MPSStream* queue); // waits on the cpu
bool queryEvent();
uint64_t getCurrentValue() { return _currentValue; }
void setCurrentValue(uint64_t currValue) { _currentValue = currValue; }
private:
bool _isRecorded = false;
uint64_t _currentValue = 0;
MTLSharedEvent_t _event;
};
typedef MPSEvent* mpsEvent_t;
} // namespace mps
} // namespace at
|