You might need to perform tensor operations on your data or integrate them into your image processing pipeline. Metal provides Metal Performance Shaders Graph (MPSGraph) to handle such tasks. It offers a vast library of tensor operations that can be organized as a graph, eliminating the need for extensive boilerplate code.
If you have experience with frameworks like Keras or PyTorch, you're likely familiar with how computational graphs are constructed. If not, the concept is straightforward: a graph represents a sequence of operations applied to an input tensor, which eventually produces an output tensor. Using MPSGraph, you can train or pre-train your ML models directly and then seamlessly integrate them into your image, data processing, or rendering pipelines.
While CoreML excels in supporting a broader range of machine learning operations and offers numerous tools for converting models from other frameworks, performance-critical applications might benefit from MPSGraph. It allows you to work directly with Metal resources, avoiding the overhead of multiple intermediate data conversions.
Importantly, MPSGraph isn't just for machine learning - it's a valuable tool for any image processing algorithm that can represent image data as tensors.
Before diving into how to use MPSGraph, let's first explore the operations it provides:
- Control flow: for loops and conditional statements (if).
- Recurrent layers: LSTM and GRU.

For a comprehensive list of supported operations and their details, refer to the official documentation.
First, you need to understand how a graph is constructed:
In MPSGraph, a tensor is an n-dimensional array that serves as the fundamental data structure for computation. Understanding these concepts is crucial, as they form the foundation for building and manipulating computational graphs in MPSGraph.
MPSGraph is just a representation, so it doesn't require a Metal device for initialization:
var graph = MPSGraph()
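As a minimal sketch of graph construction (the shape, names, and the square-plus-one operation here are purely illustrative), you create placeholders for inputs and chain operations on them:

```swift
import MetalPerformanceShadersGraph

let graph = MPSGraph()

// A placeholder is a named input slot that receives real data at run time.
let x = graph.placeholder(shape: [4], dataType: .float32, name: "x")

// Chain operations: y = x * x + 1
let squared = graph.multiplication(x, x, name: "squared")
let one = graph.constant(1.0, shape: [1], dataType: .float32)
let y = graph.addition(squared, one, name: "y")
```

Each call returns an MPSGraphTensor, so the result of one operation feeds directly into the next; nothing executes until the graph is run.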
If you need to simply run your graph without additional operations, you can do so with the following straightforward method:
let fetch = graph.run( // 1
feeds: [inputTensor: input], // 2
targetTensors: [output], // 3
targetOperations: nil) // 4
- fetch is a dictionary with the resulting data.
- We assume you've created the inputTensor placeholder and have some data in input.

You can also run the graph on an MTLCommandQueue.
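For the queue-based variant, MPSGraph provides run(with:feeds:targetTensors:targetOperations:). A sketch, assuming the inputTensor, input, and output from the snippet above:

```swift
// Assumed setup: the default device and a command queue created from it.
let device = MTLCreateSystemDefaultDevice()!
let commandQueue = device.makeCommandQueue()!

// Same shape as graph.run(feeds:...), but the work is submitted to the queue.
let fetch = graph.run(with: commandQueue,
                      feeds: [inputTensor: input],
                      targetTensors: [output],
                      targetOperations: nil)
```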
If you need to perform more than one operation (for example, data preparation, pre-/post-processing, blitting, etc.), you can use an MPSCommandBuffer, which is created from an MTLCommandQueue, and encode your graph there.
var commandQueue: MTLCommandQueue
// ... initialise the queue
let commandBuffer = MPSCommandBuffer(from: commandQueue)
// ... initialise `input`, `output`, `inputData`
let fetch = graph.encode(to: commandBuffer,
feeds: [input: inputData],
targetTensors: [output],
targetOperations: nil,
executionDescriptor: nil)
// ... parsing `fetch[output]`
commandBuffer.commit()
commandBuffer.waitUntilCompleted() // optional if you don't need to read result
Alternatively, you can base the MPSCommandBuffer on an existing MTLCommandBuffer:
var mtlCommandBuffer: MTLCommandBuffer
// ... initialise the Metal command buffer
let mpsCommandBuffer = MPSCommandBuffer(commandBuffer: mtlCommandBuffer)
// ... encoding graph
mpsCommandBuffer.commit()
mpsCommandBuffer.waitUntilCompleted() // optional if you don't need to read result
// ... further Metal command buffer processing
ATTENTION: There are nuances to MPSCommandBuffer that aren’t fully covered in the documentation. Here are some key points:
- Once we create this MPSCommandBuffer, any methods utilizing it could call commitAndContinue, so the user's original command buffer may have been committed.
- Use the rootCommandBuffer method to get the currently alive underlying MTLCommandBuffer.
- commitAndContinue() commits the underlying root MTLCommandBuffer and makes a new one on the same command queue. The MPS heap is moved forward to the new command buffer so that temporary objects used by the previous command buffer can still be used with the new one.
- This provides a way to hand already-encoded work to the Metal back end sooner. For large workloads (e.g., a neural network graph), periodically calling commitAndContinue may allow you to improve CPU/GPU parallelism without the substantial memory increases associated with double buffering. It will also help decrease overall latency.
- Any Metal schedule or completion callbacks previously attached to this object will remain attached to the old command buffer and will fire as expected as the old command buffer is scheduled and completes. If your application relies on such callbacks to coordinate retain/release of important objects that are needed for work encoded after commitAndContinue, your application should retain these objects BEFORE calling commitAndContinue, and attach new release callbacks to this object with a new completion handler so that they persist through the lifetime of the new underlying command buffer. You may do this, for example, by adding the objects to a mutable array before calling commitAndContinue, then releasing the mutable array in a new completion callback added after commitAndContinue.
- Because commitAndContinue commits the old command buffer and then switches to a new one, some aspects of command buffer completion may surprise unwary developers. For example, waitUntilCompleted called immediately after commitAndContinue asks Metal to wait for the new command buffer to finish, not the old one. Since the new command buffer presumably hasn't been committed yet, this is formally a deadlock; resources may leak and Metal may complain. Your application should either call commit before waitUntilCompleted, or capture the rootCommandBuffer from before the call to commitAndContinue and wait on that. Similarly, be sure to use the appropriate command buffer when querying the MTLCommandBuffer.status property.
- If the underlying MTLCommandBuffer also implements commitAndContinue, the message will be forwarded to that object instead. In this way, underlying predicate objects and other state will be preserved.
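Putting these points together, a defensive sketch (assuming graph, commandQueue, input, inputData, and output are already defined) captures rootCommandBuffer before encoding, so the implicit swap can be detected:

```swift
let mpsCmdBuf = MPSCommandBuffer(from: commandQueue)

// Capture the root buffer now: encoding may call commitAndContinue
// and silently replace the underlying MTLCommandBuffer.
let firstRoot = mpsCmdBuf.rootCommandBuffer

let fetch = graph.encode(to: mpsCmdBuf,
                         feeds: [input: inputData],
                         targetTensors: [output],
                         targetOperations: nil,
                         executionDescriptor: nil)

if mpsCmdBuf.rootCommandBuffer !== firstRoot {
    // The original buffer was already committed behind our back.
}

mpsCmdBuf.commit()             // commits the *current* root buffer
mpsCmdBuf.waitUntilCompleted() // safe here: the current root was just committed
```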
The following example breaks down the behavior of MPSCommandBuffer, highlighting how it manages the underlying Metal command buffer and its implications:
let mtlCmdBuf = commandQueue.makeCommandBuffer()! // 1
let mpsCmdBuf = MPSCommandBuffer(commandBuffer: mtlCmdBuf) // 2
mpsCmdBuf.rootCommandBuffer.addCompletedHandler {_ in // 3
print("Completed the 1st command buffer")
}
graph.encode(to: mpsCmdBuf, ...) // 4
mpsCmdBuf.rootCommandBuffer.addCompletedHandler {_ in
print("Completed the 2nd command buffer")
}
mpsCmdBuf.commit() // 5
1. A Metal command buffer (mtlCmdBuf) is created from the command queue (or is passed from outside). This is the initial command buffer that serves as the basis for the MPSCommandBuffer.
2. The MPSCommandBuffer wraps the existing Metal command buffer (mtlCmdBuf). At this point, mpsCmdBuf.rootCommandBuffer references the original mtlCmdBuf.
3. A completion handler is attached to mtlCmdBuf. This handler will execute when that command buffer completes execution.
4. While encoding the graph, the MPSCommandBuffer (mpsCmdBuf) is implicitly committed with commitAndContinue(). This also commits the underlying mtlCmdBuf. The MPSCommandBuffer then creates a new internal Metal command buffer, replacing the previous one. After this step, mtlCmdBuf's status changes from notEnqueued to committed, and the new internal command buffer becomes mpsCmdBuf.rootCommandBuffer.
5. mpsCmdBuf.commit() commits the new underlying Metal command buffer created in step (4).

The output will be:
Completed the 1st command buffer
Completed the 2nd command buffer
ATTENTION: This behavior ensures that MPSGraph operations are efficiently managed, but it requires careful handling of command buffer states in more intricate processing pipelines. Be mindful of this behavior when building complex pipelines, as implicit commits can disrupt your workflow if you rely on the original command buffer.
I would recommend using asserts to check MTLCommandBuffer.status during development to detect problems early in the pipeline.
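For instance (a sketch, where mtlCmdBuf is the buffer you originally created and handed to MPSGraph):

```swift
// Before encoding more work into the original buffer, verify that
// MPSGraph hasn't already committed it via commitAndContinue.
assert(mtlCmdBuf.status == .notEnqueued || mtlCmdBuf.status == .enqueued,
       "Original command buffer was already committed by MPSGraph")
```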
You can also compile your graph into an MPSGraphExecutable with fixed inputs and outputs, but that's a big topic of its own.
We've already discussed running MPSGraph on an MTLCommandQueue and the nuances of operating MTLCommandBuffers, but what about transferring data between vanilla Metal and MPSGraph?
Although we can't directly use MTLTexture or MTLBuffer with MPSGraph, we can perform the following conversions:
MTLBuffer to MPSGraphTensorData: Data stored in a Metal buffer can be wrapped into an MPSGraphTensorData object directly. This allows the buffer's contents to be used as input tensors for the graph.
let tensorData = MPSGraphTensorData(mtlBuffer: buffer, shape: shape, dataType: .float32)
MTLTexture to MPSImage to MPSGraphTensorData: Metal textures can be converted into MPSImage objects. From there, you can wrap the image into MPSGraphTensorData for graph usage.
let mpsImage = MPSImage(texture: texture, featureChannels: featureChannels)
let tensorData = MPSGraphTensorData([mpsImage])
ATTENTION:
- When converting MTLBuffer or MTLTexture to MPSGraphTensorData, ensure that the data layout and shapes align with the expected tensor format in your graph.
- Use GPU-side data export (e.g., from MPSNDArray to MTLBuffer or MTLTexture) whenever possible to minimize CPU-GPU synchronization overhead.
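The buffer round trip can be sketched as follows (inputBuffer, outputBuffer, the shape, and commandBuffer are assumptions; the export runs on the GPU timeline):

```swift
// Wrap an existing MTLBuffer as graph input -- no CPU copy involved.
let inputData = MPSGraphTensorData(mtlBuffer: inputBuffer,
                                   shape: [1, 256, 256, 4],
                                   dataType: .float32)

// ... encode and run the graph, obtaining `fetch` ...

// Export the result into another MTLBuffer without a CPU round trip.
if let ndarray = fetch[output]?.mpsndarray() {
    ndarray.exportData(with: commandBuffer,
                       to: outputBuffer,
                       destinationDataType: .float32,
                       offset: 0,
                       rowStrides: nil)
}
```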
Now let's design and build a simple graph that performs basic automatic image enhancements:

As we can see, the graph has some nodes (e.g., RGB to YUV, YUV to RGB) or even groups of nodes (e.g., Y normalization, UV shift) that aren't directly provided by the original MPSGraph but can be encapsulated as separate modules or subgraphs. To improve modularity and reusability, we can start by implementing these isolated subgraphs.
As we want to use the new node in our bigger graph, we'll just extend MPSGraph to include the RGB-to-YUV conversion using existing operations:
extension MPSGraph {
func rgb2yuv(rgbTensor: MPSGraphTensor) -> MPSGraphTensor { // 1
let rgbToYUVMatrixData = [Float]([ // 2
0.299, -0.14713, 0.615,
0.587, -0.28886, -0.51499,
0.114, 0.436, -0.10001
]).withUnsafeBufferPointer {
Data(buffer: $0)
}
let rgbToYUVMatrix = constant( // 3
rgbToYUVMatrixData,
shape: [3, 3],
dataType: .float32)
let yuvTensor = matrixMultiplication( // 4
primary: rgbTensor,
secondary: rgbToYUVMatrix,
name: "rgb2yuv")
return yuvTensor
}
}
1. We extend MPSGraph with a new method, rgb2yuv(rgbTensor:), to handle RGB-to-YUV conversion as part of the graph. The method takes a tensor representing a source RGB image and returns a tensor representing the YUV image.
2. The conversion coefficients are packed into a Data object. Note that Metal uses column-major matrix definitions, so the matrix data must be arranged accordingly.
3. The matrix data becomes a constant tensor with shape [3, 3] and data type float32.
4. With the matrixMultiplication operation, each "pixel" in the input RGB tensor is multiplied by the transformation matrix to produce the output YUV tensor. This applies the RGB-to-YUV transformation to the entire input image efficiently.

In the same way, we can define the YUV-to-RGB operation:
extension MPSGraph {
func yuv2rgb(yuvTensor: MPSGraphTensor) -> MPSGraphTensor {
let yuvToRGBMatrixData = [Float]([
1.0, 1.0, 1.0,
0.0, -0.39465, 2.03211,
1.13983, -0.58060, 0.0
]).withUnsafeBufferPointer {
Data(buffer: $0)
}
let yuvToRGBMatrix = constant(
yuvToRGBMatrixData,
shape: [3, 3],
dataType: .float32)
let rgbTensor = matrixMultiplication(
primary: yuvTensor,
secondary: yuvToRGBMatrix,
name: "yuv2rgb")
return rgbTensor
}
}
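Since matrixMultiplication multiplies each pixel row-vector by these matrices, the round trip rgb2yuv → yuv2rgb corresponds to the product of the two 3x3 matrices as stored, which should be close to the identity. A quick CPU-side sanity check in pure Swift (no Metal required):

```swift
import simd

// The same coefficients as in the Data arrays above, row by row.
let rgb2yuv = simd_float3x3(rows: [
    SIMD3<Float>(0.299, -0.14713, 0.615),
    SIMD3<Float>(0.587, -0.28886, -0.51499),
    SIMD3<Float>(0.114, 0.436, -0.10001)
])
let yuv2rgb = simd_float3x3(rows: [
    SIMD3<Float>(1.0, 1.0, 1.0),
    SIMD3<Float>(0.0, -0.39465, 2.03211),
    SIMD3<Float>(1.13983, -0.58060, 0.0)
])

// Every entry of the product should be within ~1e-4 of the identity matrix.
let roundTrip = rgb2yuv * yuv2rgb
```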
The idea of this part is very straightforward: we get the minimal and maximal values of the Y channel in our image and normalize the entire image by mapping this range to the interval [0, 1]. This process ensures that the lightness values in the image are evenly distributed, improving contrast and brightness while preserving the relative relationships between pixels.
extension MPSGraph {
func normalize(input: MPSGraphTensor) -> MPSGraphTensor {
let minVal = reductionMinimum(with: input, axes: [1, 2], name: "minVal")
let maxVal = reductionMaximum(with: input, axes: [1, 2], name: "maxVal")
let normalized = division(
subtraction(input, minVal, name: "YCentered"),
subtraction(maxVal, minVal, name: "range"),
name: "YNormalized"
)
return normalized
}
}
Here we assume that the average value of an image with good white balance should be gray. To achieve this, we compute the average value of all pixels in the image and offset each pixel by this value. This adjustment ensures that the overall color balance of the image is neutral, correcting any unwanted tints or biases caused by imbalanced colors.
extension MPSGraph {
func meanShift(input: MPSGraphTensor) -> MPSGraphTensor {
let average = mean(of: input, axes: [1, 2], name: "Average")
let shifted = subtraction(input, average, name: "Shifted")
return shifted
}
}
Now that we have most of our components ready, we can assemble the entire graph. This includes all steps of the image processing pipeline, such as slicing the RGBA input into RGB and Alpha channels, performing enhancements on the RGB channels, and finally merging them back into a complete RGBA tensor.
func buildGraph(graph: MPSGraph, rgbaTensor: MPSGraphTensor) -> MPSGraphTensor {
let rgbTensor = graph.sliceTensor(
rgbaTensor,
dimension: -1,
start: 0,
length: 3,
name: "RGB"
)
let yuvTensor = graph.rgb2yuv(rgbTensor: rgbTensor)
let yChannel = graph.sliceTensor(
yuvTensor,
dimension: -1,
start: 0,
length: 1,
name: "YChannel"
)
let yNormalized = graph.normalize(input: yChannel)
let uvChannels = graph.sliceTensor(
yuvTensor,
dimension: -1,
start: 1,
length: 2,
name: "UVChannels"
)
let shiftedUV = graph.meanShift(input: uvChannels)
let normalizedYUV = graph.concatTensors(
[yNormalized, shiftedUV],
dimension: -1,
name: "NormalizedShiftedYUV")
let normalizedRGB = graph.yuv2rgb(yuvTensor: normalizedYUV)
let alphaChannel = graph.sliceTensor(
rgbaTensor,
dimension: -1,
start: 3,
length: 1,
name: "AlphaChannel"
)
let normalizedRGBA = graph.concatTensors(
[normalizedRGB, alphaChannel],
dimension: -1,
name: "NormalizedShiftedRGBA")
return normalizedRGBA
}
func run(image: MPSImage, commandQueue: MTLCommandQueue) async -> MPSImage? {
// 1
let inputData = MPSGraphTensorData([image])
// 2
let cmdBuf = MPSCommandBuffer(from: commandQueue)
// 3
let graph = MPSGraph()
// 4
let input = graph.placeholder(shape: [-1, -1, -1, -1],
dataType: MPSDataType.float32,
name: "input")
// 5
let output = buildGraph(graph: graph, rgbaTensor: input)
// 6
let fetch = graph.encode(to: cmdBuf,
feeds: [input: inputData],
targetTensors: [output],
targetOperations: nil,
executionDescriptor: nil)
// 7
var result: MPSImage?
if let resArray = fetch[output]?.mpsndarray() {
let resDesc = resArray.descriptor()
let imgDesc = MPSImageDescriptor(
channelFormat: MPSImageFeatureChannelFormat.float16,
width: resDesc.sliceRange(forDimension: 1).length,
height: resDesc.sliceRange(forDimension: 2).length,
featureChannels: 4)
let resImage = MPSImage(device: image.device, imageDescriptor: imgDesc)
// 8
resArray.exportData(with: cmdBuf,
to: [resImage],
offset: MPSImageCoordinate(x: 0, y: 0, channel: 0))
result = resImage
}
// 9
cmdBuf.commit()
cmdBuf.waitUntilCompleted()
return result
}
1. Wrap the MPSImage into MPSGraphTensorData.
2. Create an MPSCommandBuffer from the Metal command queue.
3. Create an MPSGraph object for constructing the computation graph.
4. Define an input placeholder with a dynamic shape.
5. Build the graph using the buildGraph function from the previous section.
6. Encode the graph to the MPSCommandBuffer for execution. We could just run it, but we have a blitting operation below.
7. Get the result as an MPSNDArray and convert it back to an MPSImage.
8. Export the data from the MPSNDArray back into the MPSImage. This operation is performed on the GPU side.
9. Commit the command buffer and wait for completion so the result can be read.

As a result of running the graph, we get the following result (left is the original image):

When you perform a GPU capture in Xcode, you can get an approximate view of what’s happening under the hood of MPSGraph. While you cannot see the specific optimizations or exact implementation details of your graph (as MPSGraph handles this internally), the GPU capture allows you to inspect important aspects like memory consumption, resource flow, and overall execution timeline.
NOTE: as you can see, there are lots of resources allocated by MPSGraph that are out of your control. This could significantly impact the memory consumption of your app.

- MPSGraph provides a streamlined approach for performing tensor operations without the extensive boilerplate code required by vanilla Metal.
- While MPSGraph can be integrated into Metal pipelines, there are certain nuances and limitations, especially regarding resource management and command buffer handling.