imported and loaded ONNX files into unreal
This commit is contained in:
parent
6307ec1337
commit
d91173c3ae
Binary file not shown.
Binary file not shown.
|
@ -7,7 +7,7 @@ RWTexture2D<float4> OutputTexture;
|
||||||
Buffer<float> InputTensor;
|
Buffer<float> InputTensor;
|
||||||
|
|
||||||
[numthreads(1, 1, 1)]
|
[numthreads(1, 1, 1)]
|
||||||
void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
|
void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
|
||||||
{
|
{
|
||||||
const uint GlobalIndex = DispatchThreadID.x;
|
const uint GlobalIndex = DispatchThreadID.x;
|
||||||
|
|
||||||
|
@ -26,5 +26,10 @@ void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThre
|
||||||
const uint MinorIndex = PixelIndex % TextureSize.x;
|
const uint MinorIndex = PixelIndex % TextureSize.x;
|
||||||
const uint2 TextureCoords = uint2(MajorIndex, MinorIndex);
|
const uint2 TextureCoords = uint2(MajorIndex, MinorIndex);
|
||||||
|
|
||||||
OutputTexture[TextureCoords][ChannelIndex] = InputTensor[GlobalIndex];
|
OutputTexture[TextureCoords] = float4(
|
||||||
|
InputTensor[GlobalIndex + 0],
|
||||||
|
InputTensor[GlobalIndex + 1],
|
||||||
|
InputTensor[GlobalIndex + 2],
|
||||||
|
1.0f
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,10 +10,10 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
|
||||||
{
|
{
|
||||||
Super::Initialize(Collection);
|
Super::Initialize(Collection);
|
||||||
|
|
||||||
StyleTransferNetwork = NewObject<UNeuralNetwork>();
|
StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer"));
|
||||||
|
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor"));
|
||||||
|
|
||||||
FString ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx");
|
if (StyleTransferNetwork->IsLoaded())
|
||||||
if (StyleTransferNetwork->Load(ONNXModelFilePath))
|
|
||||||
{
|
{
|
||||||
if (StyleTransferNetwork->IsGPUSupported())
|
if (StyleTransferNetwork->IsGPUSupported())
|
||||||
{
|
{
|
||||||
|
@ -26,13 +26,11 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not loaded from %s."), *ONNXModelFilePath);
|
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded"));
|
||||||
}
|
}
|
||||||
|
|
||||||
StylePredictionNetwork = NewObject<UNeuralNetwork>();
|
|
||||||
|
|
||||||
ONNXModelFilePath = TEXT("SOME_PARENT_FOLDER/SOME_ONNX_FILE_NAME.onnx");
|
if (StylePredictionNetwork->IsLoaded())
|
||||||
if (StylePredictionNetwork->Load(ONNXModelFilePath))
|
|
||||||
{
|
{
|
||||||
if (StylePredictionNetwork->IsGPUSupported())
|
if (StylePredictionNetwork->IsGPUSupported())
|
||||||
{
|
{
|
||||||
|
@ -45,16 +43,38 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not loaded from %s."), *ONNXModelFilePath);
|
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded."));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UStyleTransferSubsystem::Deinitialize()
|
||||||
|
{
|
||||||
|
StyleTransferSceneViewExtension.Reset();
|
||||||
|
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
|
||||||
|
|
||||||
|
Super::Deinitialize();
|
||||||
|
}
|
||||||
|
|
||||||
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
|
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
|
||||||
{
|
{
|
||||||
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork);
|
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork);
|
||||||
|
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
|
||||||
}
|
}
|
||||||
|
|
||||||
void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage)
|
void UStyleTransferSubsystem::UpdateStyle(FNeuralTensor StyleImage)
|
||||||
{
|
{
|
||||||
|
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
|
||||||
|
|
||||||
|
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
|
||||||
|
{
|
||||||
|
FRDGBuilder GraphBuilder(RHICommandList);
|
||||||
|
|
||||||
|
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
|
||||||
|
|
||||||
|
// @todo: copy output of style prediction network to input of style transfer network
|
||||||
|
|
||||||
|
GraphBuilder.Execute();
|
||||||
|
});
|
||||||
|
|
||||||
|
FlushRenderingCommands();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ class STYLETRANSFER_API UStyleTransferSubsystem : public UGameInstanceSubsystem
|
||||||
public:
|
public:
|
||||||
// - UGameInstanceSubsystem
|
// - UGameInstanceSubsystem
|
||||||
virtual void Initialize(FSubsystemCollectionBase& Collection) override;
|
virtual void Initialize(FSubsystemCollectionBase& Collection) override;
|
||||||
|
virtual void Deinitialize() override;
|
||||||
// --
|
// --
|
||||||
|
|
||||||
void StartStylizingViewport(FViewportClient* ViewportClient);
|
void StartStylizingViewport(FViewportClient* ViewportClient);
|
||||||
|
@ -34,4 +35,6 @@ private:
|
||||||
|
|
||||||
UPROPERTY()
|
UPROPERTY()
|
||||||
TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
|
TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
|
||||||
|
|
||||||
|
int32 StylePredictionInferenceContext = -1;
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,6 +10,4 @@ void FOutputTensorToSceneColorCS::ModifyCompilationEnvironment(const FGlobalShad
|
||||||
FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment);
|
FGlobalShader::ModifyCompilationEnvironment(Parameters, OutEnvironment);
|
||||||
}
|
}
|
||||||
|
|
||||||
IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS,
|
IMPLEMENT_GLOBAL_SHADER(FOutputTensorToSceneColorCS, "/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf", "OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp
|
||||||
"/Plugins/StyleTransfer/Shaders/Private/OutputTensorToSceneColor.usf",
|
|
||||||
"OutputTensorToSceneColorCS", SF_Compute); // Path defined in StyleTransferModule.cpp
|
|
|
@ -17,7 +17,7 @@
|
||||||
{
|
{
|
||||||
"Name": "StyleTransfer",
|
"Name": "StyleTransfer",
|
||||||
"Type": "Runtime",
|
"Type": "Runtime",
|
||||||
"LoadingPhase": "PostConfigInit",
|
"LoadingPhase": "Default",
|
||||||
"WhitelistPlatforms": [
|
"WhitelistPlatforms": [
|
||||||
"Win64"
|
"Win64"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue