imported and loaded ONNX files into unreal

This commit is contained in:
Manuel Wagner 2022-08-23 18:43:36 +02:00
parent 6307ec1337
commit d91173c3ae
7 changed files with 46 additions and 14 deletions

BIN
Plugins/StyleTransfer/Content/NN_StylePredictor.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/NN_StyleTransfer.uasset (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -7,7 +7,7 @@ RWTexture2D<float4> OutputTexture;
Buffer<float> InputTensor; Buffer<float> InputTensor;
[numthreads(1, 1, 1)] [numthreads(1, 1, 1)]
void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID) void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{ {
const uint GlobalIndex = DispatchThreadID.x; const uint GlobalIndex = DispatchThreadID.x;
@ -26,5 +26,10 @@ void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThre
const uint MinorIndex = PixelIndex % TextureSize.x; const uint MinorIndex = PixelIndex % TextureSize.x;
const uint2 TextureCoords = uint2(MajorIndex, MinorIndex); const uint2 TextureCoords = uint2(MajorIndex, MinorIndex);
OutputTexture[TextureCoords][ChannelIndex] = InputTensor[GlobalIndex]; OutputTexture[TextureCoords] = float4(
InputTensor[GlobalIndex + 0],
InputTensor[GlobalIndex + 1],
InputTensor[GlobalIndex + 2],
1.0f
);
} }

View File

@ -10,10 +10,10 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
{ {
Super::Initialize(Collection); Super::Initialize(Collection);
StyleTransferNetwork = NewObject<UNeuralNetwork>(); StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer"));
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor"));
FString ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx"); if (StyleTransferNetwork->IsLoaded())
if (StyleTransferNetwork->Load(ONNXModelFilePath))
{ {
if (StyleTransferNetwork->IsGPUSupported()) if (StyleTransferNetwork->IsGPUSupported())
{ {
@ -26,13 +26,11 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
} }
else else
{ {
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not loaded from %s."), *ONNXModelFilePath); UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded"));
} }
StylePredictionNetwork = NewObject<UNeuralNetwork>();
ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx"); if (StylePredictionNetwork->IsLoaded())
if (StylePredictionNetwork->Load(ONNXModelFilePath))
{ {
if (StylePredictionNetwork->IsGPUSupported()) if (StylePredictionNetwork->IsGPUSupported())
{ {
@ -45,16 +43,38 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
} }
else else
{ {
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not loaded from %s."), *ONNXModelFilePath); UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded."));
} }
} }
void UStyleTransferSubsystem::Deinitialize()
{
StyleTransferSceneViewExtension.Reset();
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient) void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{ {
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork); StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork);
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
} }
void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage) void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage)
{ {
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
// @todo: copy output of style prediction network to input of style transfer network
GraphBuilder.Execute();
});
FlushRenderingCommands();
} }

View File

@ -20,6 +20,7 @@ class STYLETRANSFER_API UStyleTransferSubsystem : public UGameInstanceSubsystem
public: public:
// - UGameInstanceSubsystem // - UGameInstanceSubsystem
virtual void Initialize(FSubsystemCollectionBase& Collection) override; virtual void Initialize(FSubsystemCollectionBase& Collection) override;
virtual void Deinitialize() override;
// -- // --
void StartStylizingViewport(FViewportClient* ViewportClient); void StartStylizingViewport(FViewportClient* ViewportClient);
@ -34,4 +35,6 @@ private:
UPROPERTY() UPROPERTY()
TObjectPtr<UNeuralNetwork> StylePredictionNetwork; TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
int32 StylePredictionInferenceContext = -1;
}; };

View File

@ -10,6 +10,4 @@ void FOutputTensorToSceneColorCS::ModifyCompilationEnvironment(const FGlobalShad
FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment); FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment);
} }
IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, "/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf", "OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp
"/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf",
"OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp

View File

@ -17,7 +17,7 @@
{ {
"Name": "StyleTransfer", "Name": "StyleTransfer",
"Type": "Runtime", "Type": "Runtime",
"LoadingPhase": "PostConfigInit", "LoadingPhase": "Default",
"WhitelistPlatforms": [ "WhitelistPlatforms": [
"Win64" "Win64"
] ]