Executing multiple ML models in a chain
Windows ML supports high-performance loading and execution of model chains by carefully optimizing its GPU path. A model chain consists of two or more models that execute sequentially, where the outputs of one model become the inputs to the next model in the chain.
To explain how to chain models efficiently with Windows ML, let's use an FNS-Candy Style Transfer ONNX model as a toy example. You can find this type of model in the FNS-Candy Style Transfer sample folder in our GitHub repository.
Let's say we want to execute a chain that is composed of two instances of the same FNS-Candy model, here called mosaic.onnx. The application code would pass an image to the first model in the chain, let it compute the outputs, and then pass that transformed image to another instance of FNS-Candy, producing a final image.
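Conceptually, this data flow is just function composition. The following is a minimal sketch in plain C++, not Windows ML code: `Tensor`, `Stage`, `RunChain`, and `Halve` are hypothetical names introduced only for illustration, with `Halve` standing in for a real model.

```cpp
#include <functional>
#include <vector>

// Hypothetical stand-ins: a tensor is a flat list of floats, and a stage
// (a "model") is any function that maps one tensor to another.
using Tensor = std::vector<float>;
using Stage = std::function<Tensor(const Tensor&)>;

// Runs the stages in order, feeding each stage's output to the next stage.
Tensor RunChain(const std::vector<Stage>& stages, Tensor data)
{
    for (const Stage& stage : stages)
    {
        data = stage(data);
    }
    return data;
}

// Toy stage used in place of a real model: halves every value.
Tensor Halve(const Tensor& input)
{
    Tensor output;
    output.reserve(input.size());
    for (float value : input)
    {
        output.push_back(value / 2.0f);
    }
    return output;
}
```

Calling `RunChain({Halve, Halve}, {8.0f, 4.0f, 2.0f})` applies the same stage twice, just as the example chain runs mosaic.onnx twice, and returns `{2.0f, 1.0f, 0.5f}`.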
The following steps illustrate how to accomplish that using Windows ML.
Note
In a real-world scenario, you would most likely use two different models, but this should be enough to illustrate the concepts.
- First, let's load the mosaic.onnx model so that we can use it.
C++/WinRT
std::wstring filePath = L"path\\to\\mosaic.onnx";
LearningModel model = LearningModel::LoadFromFilePath(filePath);
C#
string filePath = "path\\to\\mosaic.onnx";
LearningModel model = LearningModel.LoadFromFilePath(filePath);
- Then, let's create two identical sessions on the device's default GPU, passing the same model to each.
C++/WinRT
LearningModelSession session1(model, LearningModelDevice(LearningModelDeviceKind::DirectX));
LearningModelSession session2(model, LearningModelDevice(LearningModelDeviceKind::DirectX));
C#
LearningModelSession session1 =
    new LearningModelSession(model, new LearningModelDevice(LearningModelDeviceKind.DirectX));
LearningModelSession session2 =
    new LearningModelSession(model, new LearningModelDevice(LearningModelDeviceKind.DirectX));
Note
In order to reap the performance benefits of chaining, you need to create identical GPU sessions for all of your models. Not doing so would result in additional data movement out of the GPU and into the CPU, which would reduce performance.
- The following lines of code will create bindings for each session:
C++/WinRT
LearningModelBinding binding1(session1);
LearningModelBinding binding2(session2);
C#
LearningModelBinding binding1 = new LearningModelBinding(session1);
LearningModelBinding binding2 = new LearningModelBinding(session2);
- Next, we will bind an input for our first model. We will pass in an image that is located in the same path as our model. In this example the image is called "fish_720.png".
C++/WinRT
// Get the input descriptor
ILearningModelFeatureDescriptor input = model.InputFeatures().GetAt(0);
// Path to the image that will be loaded as a SoftwareBitmap
hstring imagePath = L"path\\to\\fish_720.png";
// Load the image and bind it to the model's input
try
{
    StorageFile file = StorageFile::GetFileFromPathAsync(imagePath).get();
    IRandomAccessStream stream = file.OpenAsync(FileAccessMode::Read).get();
    BitmapDecoder decoder = BitmapDecoder::CreateAsync(stream).get();
    SoftwareBitmap softwareBitmap = decoder.GetSoftwareBitmapAsync().get();
    VideoFrame videoFrame = VideoFrame::CreateWithSoftwareBitmap(softwareBitmap);
    ImageFeatureValue image = ImageFeatureValue::CreateFromVideoFrame(videoFrame);
    binding1.Bind(input.Name(), image);
}
catch (...)
{
    printf("Failed to load/bind image\n");
}
C#
// Get the input descriptor
ILearningModelFeatureDescriptor input = model.InputFeatures[0];
// Path to the image that will be loaded as a SoftwareBitmap
string imagePath = "path\\to\\fish_720.png";
// Load the image and bind it to the model's input
try
{
    StorageFile file = await StorageFile.GetFileFromPathAsync(imagePath);
    IRandomAccessStream stream = await file.OpenAsync(FileAccessMode.Read);
    BitmapDecoder decoder = await BitmapDecoder.CreateAsync(stream);
    SoftwareBitmap softwareBitmap = await decoder.GetSoftwareBitmapAsync();
    VideoFrame videoFrame = VideoFrame.CreateWithSoftwareBitmap(softwareBitmap);
    ImageFeatureValue image = ImageFeatureValue.CreateFromVideoFrame(videoFrame);
    binding1.Bind(input.Name, image);
}
catch
{
    Console.WriteLine("Failed to load/bind image");
}
- For the next model in the chain to use the outputs of the first model's evaluation, we need to create an empty output tensor and bind it as the first model's output. This gives us a tensor that we can later bind as the second model's input:
C++/WinRT
// Get the output descriptor
ILearningModelFeatureDescriptor output = model.OutputFeatures().GetAt(0);
// Create an empty output tensor
std::vector<int64_t> shape = {1, 3, 720, 720};
TensorFloat outputValue = TensorFloat::Create(shape);
// Bind the (empty) output
binding1.Bind(output.Name(), outputValue);
C#
// Get the output descriptor
ILearningModelFeatureDescriptor output = model.OutputFeatures[0];
// Create an empty output tensor
List<long> shape = new List<long> { 1, 3, 720, 720 };
TensorFloat outputValue = TensorFloat.Create(shape);
// Bind the (empty) output
binding1.Bind(output.Name, outputValue);
Note
You must use the TensorFloat data type when binding the output. This prevents de-tensorization from occurring once evaluation of the first model is completed, which also avoids additional GPU queueing for load and bind operations for the second model.
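As a rough sense of scale for the tensor created above, this sketch computes how many elements and bytes a 1 x 3 x 720 x 720 float tensor holds. It is plain C++ rather than Windows ML code; `ElementCount` and `ByteSize` are hypothetical helper names, and the 4 bytes per element assumes that TensorFloat stores 32-bit floats.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Assumption: each TensorFloat element is a 32-bit (4-byte) float.
constexpr std::size_t kBytesPerFloat = 4;

// Number of elements in a tensor with the given dimensions.
std::size_t ElementCount(const std::vector<std::int64_t>& shape)
{
    std::size_t count = 1;
    for (std::int64_t dim : shape)
    {
        count *= static_cast<std::size_t>(dim);
    }
    return count;
}

// Size in bytes of a float tensor with the given dimensions.
std::size_t ByteSize(const std::vector<std::int64_t>& shape)
{
    return ElementCount(shape) * kBytesPerFloat;
}
```

For the shape `{1, 3, 720, 720}`, `ElementCount` returns 1,555,200 and `ByteSize` returns 6,220,800 bytes (about 5.9 MiB). That is the amount of intermediate data per evaluation that could otherwise be subject to extra conversion or copying between the two models.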
- Now, we run the evaluation of the first model, and bind its outputs to the next model's input:
C++/WinRT
// Run session1 evaluation and wait for it to complete
session1.EvaluateAsync(binding1, L"").get();
// Bind the output to the next model's input
binding2.Bind(input.Name(), outputValue);
// Start session2 evaluation
auto session2AsyncOp = session2.EvaluateAsync(binding2, L"");
C#
// Run session1 evaluation
await session1.EvaluateAsync(binding1, "");
// Bind the output to the next model's input
binding2.Bind(input.Name, outputValue);
// Run session2 evaluation
LearningModelEvaluationResult results = await session2.EvaluateAsync(binding2, "");
- Finally, let's retrieve the final output produced after running both models by using the following line of code.
C++/WinRT
auto finalOutput = session2AsyncOp.get().Outputs().First().Current().Value();
C#
var finalOutput = results.Outputs.First().Value;
That's it! Both of your models can now execute sequentially while making the most of the available GPU resources.
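Chaining works here because both sessions use the same model, so the first output has the shape that the second input expects. With two different models, you may want to compare shapes before binding. The sketch below is plain C++ with a hypothetical helper, `ShapesCompatible`; it assumes that a declared dimension of -1 means a free (variable-size) dimension, which you should verify against the feature descriptors of your own models.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true when a tensor of shape `produced` could be bound to an input
// that declares shape `expected`.
// Assumption: a declared dimension of -1 accepts any size.
bool ShapesCompatible(const std::vector<std::int64_t>& produced,
                      const std::vector<std::int64_t>& expected)
{
    // The number of dimensions must match exactly.
    if (produced.size() != expected.size())
    {
        return false;
    }
    // Every fixed dimension must match; free dimensions match anything.
    for (std::size_t i = 0; i < produced.size(); ++i)
    {
        if (expected[i] != -1 && expected[i] != produced[i])
        {
            return false;
        }
    }
    return true;
}
```

In this article's example, both shapes are `{1, 3, 720, 720}`, so the check passes. It would fail if the second model declared, say, a `{1, 3, 224, 224}` input.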
Note
Use the following resources for help with Windows ML:
- To ask or answer technical questions about Windows ML, please use the windows-machine-learning tag on Stack Overflow.
- To report a bug, please file an issue on our GitHub.