imported and loaded ONNX files into unreal

This commit is contained in:
Manuel Wagner 2022-08-23 18:43:36 +02:00
parent 6307ec1337
commit d91173c3ae
7 changed files with 46 additions and 14 deletions

BIN
Plugins/StyleTransfer/Content/NN_StylePredictor.uasset (Stored with Git LFS) Normal file

Binary file not shown.

BIN
Plugins/StyleTransfer/Content/NN_StyleTransfer.uasset (Stored with Git LFS) Normal file

Binary file not shown.

View File

@ -7,7 +7,7 @@ RWTexture2D<float4> OutputTexture;
Buffer<float> InputTensor;
[numthreads(1, 1, 1)]
void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{
const uint GlobalIndex = DispatchThreadID.x;
@ -26,5 +26,10 @@ void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThre
const uint MinorIndex = PixelIndex % TextureSize.x;
const uint2 TextureCoords = uint2(MajorIndex, MinorIndex);
OutputTexture[TextureCoords][ChannelIndex] = InputTensor[GlobalIndex];
OutputTexture[TextureCoords] = float4(
InputTensor[GlobalIndex + 0],
InputTensor[GlobalIndex + 1],
InputTensor[GlobalIndex + 2],
1.0f
);
}

View File

@ -10,10 +10,10 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
{
Super::Initialize(Collection);
StyleTransferNetwork = NewObject<UNeuralNetwork>();
StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer"));
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor"));
FString ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx");
if (StyleTransferNetwork->Load(ONNXModelFilePath))
if (StyleTransferNetwork->IsLoaded())
{
if (StyleTransferNetwork->IsGPUSupported())
{
@ -26,13 +26,11 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
}
else
{
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not loaded from %s."), *ONNXModelFilePath);
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded"));
}
StylePredictionNetwork = NewObject<UNeuralNetwork>();
ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx");
if (StylePredictionNetwork->Load(ONNXModelFilePath))
if (StylePredictionNetwork->IsLoaded())
{
if (StylePredictionNetwork->IsGPUSupported())
{
@ -45,16 +43,38 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
}
else
{
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not loaded from %s."), *ONNXModelFilePath);
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded."));
}
}
void UStyleTransferSubsystem::Deinitialize()
{
StyleTransferSceneViewExtension.Reset();
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork);
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
}
void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage)
{
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
// @todo: copy output of style prediction network to input of style transfer network
GraphBuilder.Execute();
});
FlushRenderingCommands();
}

View File

@ -20,6 +20,7 @@ class STYLETRANSFER_API UStyleTransferSubsystem : public UGameInstanceSubsystem
public:
// - UGameInstanceSubsystem
virtual void Initialize(FSubsystemCollectionBase& Collection) override;
virtual void Deinitialize() override;
// --
void StartStylizingViewport(FViewportClient* ViewportClient);
@ -34,4 +35,6 @@ private:
UPROPERTY()
TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
int32 StylePredictionInferenceContext = -1;
};

View File

@ -10,6 +10,4 @@ void FOutputTensorToSceneColorCS::ModifyCompilationEnvironment(const FGlobalShad
FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment);
}
IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS,
"/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf",
"OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp
IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, "/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf", "OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp

View File

@ -17,7 +17,7 @@
{
"Name": "StyleTransfer",
"Type": "Runtime",
"LoadingPhase": "PostConfigInit",
"LoadingPhase": "Default",
"WhitelistPlatforms": [
"Win64"
]