diff --git a/Plugins/StyleTransfer/Content/NN_StylePredictor.uasset b/Plugins/StyleTransfer/Content/NN_StylePredictor.uasset new file mode 100644 index 00000000..65f999e0 --- /dev/null +++ b/Plugins/StyleTransfer/Content/NN_StylePredictor.uasset @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5ffdd6afea702024f6fa2d00f2dbf90bef02fc307e61c9382b0cbeba21282059 +size 4092053 diff --git a/Plugins/StyleTransfer/Content/NN_StyleTransfer.uasset b/Plugins/StyleTransfer/Content/NN_StyleTransfer.uasset new file mode 100644 index 00000000..7b862770 --- /dev/null +++ b/Plugins/StyleTransfer/Content/NN_StyleTransfer.uasset @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c80d8e6510bae82db869b7b0175b20ee71163b8de7dec675eead1c282944129 +size 6700104 diff --git a/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf b/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf index a301fb59..e0af32a0 100644 --- a/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf +++ b/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf @@ -7,7 +7,7 @@ RWTexture2D OutputTexture; Buffer InputTensor; [numthreads(1, 1, 1)] -void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID) +void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID) { const uint GlobalIndex = DispatchThreadID.x; @@ -26,5 +26,10 @@ void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThre const uint MinorIndex = PixelIndex % TextureSize.x; const uint2 TextureCoords = uint2(MajorIndex, MinorIndex); - OutputTexture[TextureCoords][ChannelIndex] = InputTensor[GlobalIndex]; + OutputTexture[TextureCoords] = float4( + InputTensor[GlobalIndex + 0], + InputTensor[GlobalIndex + 1], + InputTensor[GlobalIndex + 2], + 1.0f + ); } diff --git a/Plugins/StyleTransfer/Source/StyleTransfer/Private/StyleTransferSubsystem.cpp b/Plugins/StyleTransfer/Source/StyleTransfer/Private/StyleTransferSubsystem.cpp index aff9bd39..e46b26b0 100644 --- a/Plugins/StyleTransfer/Source/StyleTransfer/Private/StyleTransferSubsystem.cpp +++ b/Plugins/StyleTransfer/Source/StyleTransfer/Private/StyleTransferSubsystem.cpp @@ -10,10 +10,10 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection) { Super::Initialize(Collection); - StyleTransferNetwork = NewObject(); + StyleTransferNetwork = LoadObject(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer")); + StylePredictionNetwork = LoadObject(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor")); - FString ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx"); - if (StyleTransferNetwork->Load(ONNXModelFilePath)) + if (StyleTransferNetwork->IsLoaded()) { if (StyleTransferNetwork->IsGPUSupported()) { @@ -26,13 +26,11 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection) } else { - UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not loaded from %s."), *ONNXModelFilePath); + UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded")); } - StylePredictionNetwork = NewObject(); - ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx"); - if (StylePredictionNetwork->Load(ONNXModelFilePath)) + if (StylePredictionNetwork->IsLoaded()) { if (StylePredictionNetwork->IsGPUSupported()) { @@ -45,16 +43,38 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection) } else { - UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not loaded from %s."), *ONNXModelFilePath); + UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded.")); } } +void UStyleTransferSubsystem::Deinitialize() +{ + StyleTransferSceneViewExtension.Reset(); + StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext); + + Super::Deinitialize(); +} + void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient) { StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension(ViewportClient, StyleTransferNetwork); + StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext(); } void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage) { + StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy()); + ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList) + { + FRDGBuilder GraphBuilder(RHICommandList); + + StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext); + + // @todo: copy output of style prediction network to input of style transfer network + + GraphBuilder.Execute(); + }); + + FlushRenderingCommands(); } diff --git a/Plugins/StyleTransfer/Source/StyleTransfer/Public/StyleTransferSubsystem.h b/Plugins/StyleTransfer/Source/StyleTransfer/Public/StyleTransferSubsystem.h index 6436c3b7..1686e2f2 100644 --- a/Plugins/StyleTransfer/Source/StyleTransfer/Public/StyleTransferSubsystem.h +++ b/Plugins/StyleTransfer/Source/StyleTransfer/Public/StyleTransferSubsystem.h @@ -20,6 +20,7 @@ class STYLETRANSFER_API UStyleTransferSubsystem : public UGameInstanceSubsystem public: // - UGameInstanceSubsystem virtual void Initialize(FSubsystemCollectionBase& Collection) override; + virtual void Deinitialize() override; // -- void StartStylizingViewport(FViewportClient* ViewportClient); @@ -34,4 +35,6 @@ private: UPROPERTY() TObjectPtr StylePredictionNetwork; + + int32 StylePredictionInferenceContext = -1; }; diff --git a/Plugins/StyleTransfer/Source/StyleTransferShaders/Private/OutputTensorToSceneColorCS.cpp b/Plugins/StyleTransfer/Source/StyleTransferShaders/Private/OutputTensorToSceneColorCS.cpp index f9791020..bf2589f0 100644 --- a/Plugins/StyleTransfer/Source/StyleTransferShaders/Private/OutputTensorToSceneColorCS.cpp +++ b/Plugins/StyleTransfer/Source/StyleTransferShaders/Private/OutputTensorToSceneColorCS.cpp @@ -10,6 +10,4 @@ void FOutputTensorToSceneColorCS::ModifyCompilationEnvironment(const FGlobalShad FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment); } -IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, - "/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf", - "OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp +IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, "/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf", "OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp \ No newline at end of file diff --git a/Plugins/StyleTransfer/StyleTransfer.uplugin b/Plugins/StyleTransfer/StyleTransfer.uplugin index e9798c49..67ed84fa 100644 --- a/Plugins/StyleTransfer/StyleTransfer.uplugin +++ b/Plugins/StyleTransfer/StyleTransfer.uplugin @@ -17,7 +17,7 @@ { "Name": "StyleTransfer", "Type": "Runtime", - "LoadingPhase": "PostConfigInit", + "LoadingPhase": "Default", "WhitelistPlatforms": [ "Win64" ]