Fixed the compute shaders for tensor conversion

Should in theory work now but it does not do anything
This commit is contained in:
Manuel Wagner 2022-08-26 17:33:40 +02:00
parent 7ec56a0475
commit e717b41489
10 changed files with 300 additions and 170 deletions

Binary file not shown.

Binary file not shown.

View File

@ -3,23 +3,21 @@
#include "/Engine/Public/Platform.ush" #include "/Engine/Public/Platform.ush"
// this assumes that the OutputTexture has
// the exact same dimensions as InputTensor!
uint TensorVolume;
uint2 TextureSize;
RWTexture2D<float4> OutputTexture; RWTexture2D<float4> OutputTexture;
Buffer<float> InputTensor; Buffer<float> InputTensor;
// DispatchThreadID corresponds to InputTensor shape dimensions not texture XY -> DispatchThreadID.X = Texture.Y
[numthreads(THREADGROUP_SIZE_X, THREADGROUP_SIZE_Y, THREADGROUP_SIZE_Z)] [numthreads(THREADGROUP_SIZE_X, THREADGROUP_SIZE_Y, THREADGROUP_SIZE_Z)]
void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID) void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{ {
uint TensorVolume;
InputTensor.GetDimensions(TensorVolume);
uint2 TextureSize = 0;
OutputTexture.GetDimensions(TextureSize.x, TextureSize.y);
// note that the input tensor has shape (1, Y, X, C) // note that the input tensor has shape (1, Y, X, C)
// which is why we need to flip the indexing // which is why we need to flip the indexing
const uint PixelIndex = DispatchThreadID.x * TextureSize.y + DispatchThreadID.x; const uint TensorPixelNumber = DispatchThreadID.x * TextureSize.x + DispatchThreadID.y;
const uint GlobalIndex = PixelIndex * 3; const uint GlobalIndex = TensorPixelNumber * 3;
if (GlobalIndex >= TensorVolume) if (GlobalIndex >= TensorVolume)
{ {
@ -27,10 +25,11 @@ void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThr
} }
const uint2 TextureCoords = uint2(DispatchThreadID.y, DispatchThreadID.x); const uint2 TextureCoords = uint2(DispatchThreadID.y, DispatchThreadID.x);
OutputTexture[TextureCoords] = float4( const float4 RGBAColor = float4(
InputTensor[GlobalIndex + 0], InputTensor[GlobalIndex + 0],
InputTensor[GlobalIndex + 1], InputTensor[GlobalIndex + 1],
InputTensor[GlobalIndex + 2], InputTensor[GlobalIndex + 2],
1.0f 0.0f
); );
OutputTexture[TextureCoords] = RGBAColor;
} }

View File

@ -12,16 +12,16 @@ float2 HalfPixelUV;
void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID) void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{ {
const uint2 OutputUAVTexelCoordinate = DispatchThreadID.xy; const uint2 OutputUAVTexelCoordinate = DispatchThreadID.xy;
if(any(OutputUAVTexelCoordinate > OutputDimensions)) if(any(OutputUAVTexelCoordinate >= OutputDimensions))
{ {
return; return;
} }
const uint GlobalIndex = (OutputUAVTexelCoordinate.x * OutputDimensions.x + OutputUAVTexelCoordinate.y) * 3; const uint GlobalIndex = ((OutputUAVTexelCoordinate.x * OutputDimensions.y + OutputUAVTexelCoordinate.y)*3);
// note that the input tensor has shape (1, Y, X, C) // note that the OutputUAV has shape (1, Y, X, C)
// which is why we need to flip the indexing // which is why we need to flip the indexing
const float2 UV = OutputUAVTexelCoordinate.yx / OutputDimensions.yx + HalfPixelUV; const float2 UV = float2(OutputUAVTexelCoordinate.yx) / float2(OutputDimensions.yx) + HalfPixelUV;
const float4 TextureValue = InputTexture.SampleLevel(InputTextureSampler, UV, 0); const float4 TextureValue = InputTexture.SampleLevel(InputTextureSampler, UV, 0);

View File

@ -18,7 +18,9 @@
#include "Containers/DynamicRHIResourceArray.h" #include "Containers/DynamicRHIResourceArray.h"
#include "PostProcess/PostProcessMaterial.h" #include "PostProcess/PostProcessMaterial.h"
#include "OutputTensorToSceneColorCS.h" #include "OutputTensorToSceneColorCS.h"
#include "PixelShaderUtils.h"
#include "SceneColorToInputTensorCS.h" #include "SceneColorToInputTensorCS.h"
#include "StyleTransferModule.h"
template <class OutType, class InType> template <class OutType, class InType>
OutType CastNarrowingSafe(InType InValue) OutType CastNarrowingSafe(InType InValue)
@ -35,8 +37,9 @@ OutType CastNarrowingSafe(InType InValue)
} }
FStyleTransferSceneViewExtension::FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, int32 InInferenceContext) FStyleTransferSceneViewExtension::FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, TSharedRef<int32> InInferenceContext)
: FSceneViewExtensionBase(AutoRegister) : FSceneViewExtensionBase(AutoRegister)
, StyleTransferNetworkWeakPtr(InStyleTransferNetwork)
, StyleTransferNetwork(InStyleTransferNetwork) , StyleTransferNetwork(InStyleTransferNetwork)
, LinkedViewportClient(AssociatedViewportClient) , LinkedViewportClient(AssociatedViewportClient)
, InferenceContext(InInferenceContext) , InferenceContext(InInferenceContext)
@ -48,9 +51,54 @@ void FStyleTransferSceneViewExtension::SetupViewFamily(FSceneViewFamily& InViewF
{ {
} }
void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProcessingPass PassId, bool FStyleTransferSceneViewExtension::IsActiveThisFrame_Internal(const FSceneViewExtensionContext& Context) const
FAfterPassCallbackDelegateArray& {
InOutPassCallbacks, bool bIsPassEnabled) check(IsInGameThread());
return bIsEnabled && *InferenceContext != -1 && StyleTransferNetworkWeakPtr.IsValid();
}
void FStyleTransferSceneViewExtension::AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget)
{
FGlobalShaderMap* ShaderMap = GetGlobalShaderMap(GMaxRHIFeatureLevel);
TShaderMapRef<FScreenPassVS> VertexShader(ShaderMap);
TShaderMapRef<FCopyRectPS> PixelShader(ShaderMap);
FCopyRectPS::FParameters* PixelShaderParameters = GraphBuilder.AllocParameters<FCopyRectPS::FParameters>();
PixelShaderParameters->InputTexture = &RDGSourceTexture;
PixelShaderParameters->InputSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
PixelShaderParameters->RenderTargets[0] = DestinationRenderTarget.GetRenderTargetBinding();
ClearUnusedGraphResources(PixelShader, PixelShaderParameters);
FRHIBlendState* BlendState = FScreenPassPipelineState::FDefaultBlendState::GetRHI();
FRHIDepthStencilState* DepthStencilState = FScreenPassPipelineState::FDefaultDepthStencilState::GetRHI();
const FScreenPassPipelineState PipelineState(VertexShader, PixelShader, BlendState, DepthStencilState);
GraphBuilder.AddPass(
RDG_EVENT_NAME("RescalingTextureCopy"),
PixelShaderParameters,
ERDGPassFlags::Raster,
[PipelineState, Extent = DestinationRenderTarget.Texture->Desc.Extent, PixelShader, PixelShaderParameters](FRHICommandList& RHICmdList)
{
PipelineState.Validate();
RHICmdList.SetViewport(0.0f, 0.0f, 0.0f, Extent.X, Extent.Y, 1.0f);
SetScreenPassPipelineState(RHICmdList, PipelineState);
SetShaderParameters(RHICmdList, PixelShader, PixelShader.GetPixelShader(), *PixelShaderParameters);
DrawRectangle(
RHICmdList,
0, 0, Extent.X, Extent.Y,
0, 0, Extent.X, Extent.Y,
Extent,
Extent,
PipelineState.VertexShader,
EDRF_UseTriangleOptimization);
});
}
void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProcessingPass PassId, FAfterPassCallbackDelegateArray& InOutPassCallbacks, bool bIsPassEnabled)
{ {
if (PassId == EPostProcessingPass::Tonemap) if (PassId == EPostProcessingPass::Tonemap)
{ {
@ -60,8 +108,88 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
} }
} }
FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap_RenderThread( FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor)
FRDGBuilder& GraphBuilder, const FSceneView& View, const FPostProcessMaterialInputs& InOutInputs) {
FIntVector SourceTensorDimensions = {
CastNarrowingSafe<int32>(SourceTensor.GetSize(1)),
CastNarrowingSafe<int32>(SourceTensor.GetSize(2)),
CastNarrowingSafe<int32>(SourceTensor.GetSize(3)),
};
// Reusing the same output description for our back buffer as SceneColor
FRDGTextureDesc DestinationDesc = BaseDestinationDesc;
// this is flipped because the Output tensor has the vertical dimension first
// while unreal has the horizontal dimension first
DestinationDesc.Extent = {SourceTensorDimensions[1], SourceTensorDimensions[0]};
DestinationDesc.Flags |= TexCreate_RenderTargetable | TexCreate_UAV;
FLinearColor ClearColor(0., 0., 0., 0.);
DestinationDesc.ClearValue = FClearValueBinding(ClearColor);
FRDGTexture* OutputTexture = GraphBuilder.CreateTexture(
DestinationDesc, TEXT("OutputTexture"));
FRDGBufferRef SourceTensorBuffer = GraphBuilder.RegisterExternalBuffer(SourceTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(SourceTensorBuffer, EPixelFormat::PF_R32_FLOAT);
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(OutputTexture);
OutputTensorToSceneColorParameters->TensorVolume = SourceTensor.Num();
OutputTensorToSceneColorParameters->TextureSize = DestinationDesc.Extent;
FIntVector OutputTensorToSceneColorGroupCount = FComputeShaderUtils::GetGroupCount(
{SourceTensorDimensions.X, SourceTensorDimensions.Y, 1},
FOutputTensorToSceneColorCS::ThreadGroupSize
);
TShaderMapRef<FOutputTensorToSceneColorCS> OutputTensorToSceneColorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("TensorToTexture"),
OutputTensorToSceneColorParameters,
ERDGPassFlags::Compute,
[OutputTensorToSceneColorCS, OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, OutputTensorToSceneColorCS,
*OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount);
}
);
return OutputTexture;
}
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor)
{
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
};
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent;
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
SceneColorToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / SceneColorRenderTargetDimensions.X, 0.5 / SceneColorRenderTargetDimensions.Y);
FIntVector SceneColorToInputTensorGroupCount = FComputeShaderUtils::GetGroupCount(
{InputTensorDimensions.X, InputTensorDimensions.Y, 1},
FSceneColorToInputTensorCS::ThreadGroupSize
);
TShaderMapRef<FSceneColorToInputTensorCS> SceneColorToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("TextureToTensor"),
SceneColorToInputTensorParameters,
ERDGPassFlags::Compute,
[SceneColorToInputTensorCS, SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, SceneColorToInputTensorCS,
*SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount);
}
);
}
FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap_RenderThread(FRDGBuilder& GraphBuilder, const FSceneView& View, const FPostProcessMaterialInputs& InOutInputs)
{ {
const FSceneViewFamily& ViewFamily = *View.Family; const FSceneViewFamily& ViewFamily = *View.Family;
@ -79,103 +207,29 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
RDG_EVENT_SCOPE(GraphBuilder, "StyleTransfer"); RDG_EVENT_SCOPE(GraphBuilder, "StyleTransfer");
//Get input and output viewports. Backbuffer could be targeting a different region than input viewport
const FScreenPassTextureViewport SceneColorViewport(SceneColor);
FScreenPassRenderTarget SceneColorRenderTarget(SceneColor, ERenderTargetLoadAction::ELoad);
checkSlow(View.bIsViewInfo); checkSlow(View.bIsViewInfo);
const FViewInfo& ViewInfo = static_cast<const FViewInfo&>(View); const FViewInfo& ViewInfo = static_cast<const FViewInfo&>(View);
/*AddDrawScreenPass(GraphBuilder, RDG_EVENT_NAME("ProcessOCIOColorSpaceXfrm"), ViewInfo, BackBufferViewport, const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
SceneColorViewport, OCIOPixelShader, Parameters);*/
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(InferenceContext, 0); TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor);
const FIntVector InputTensorDimensions = { StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(1)),
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(2)),
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(3)),
};
const FIntPoint SceneColorRenderTargetDimensions = SceneColorRenderTarget.Texture->Desc.Extent;
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(StyleTransferContentInputTensor.GetPooledBuffer()); const FNeuralTensor& StyleTransferContentOutputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>(); FRDGTexture* StyleTransferRenderTargetTexture = TensorToTexture(GraphBuilder, SceneColor.Texture->Desc, StyleTransferContentOutputTensor);
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(StyleTransferContentInputTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SceneColorRenderTarget.Texture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
SceneColorToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / SceneColorRenderTargetDimensions.X, 0.5 / SceneColorRenderTargetDimensions.Y);
FIntVector SceneColorToInputTensorGroupCount = FComputeShaderUtils::GetGroupCount(
{InputTensorDimensions.X, InputTensorDimensions.Y, 1},
FSceneColorToInputTensorCS::ThreadGroupSize
);
TShaderMapRef<FSceneColorToInputTensorCS> SceneColorToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("SceneColorToInputTensor"),
SceneColorToInputTensorParameters,
ERDGPassFlags::Compute,
[SceneColorToInputTensorCS, SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, SceneColorToInputTensorCS,
*SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount);
}
);
const FNeuralTensor& StyleTransferOutputTensor = StyleTransferNetwork->GetOutputTensorForContext(InferenceContext, 0);
FIntVector OutputTensorDimensions = {
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(1)),
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(2)),
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(3)),
};
// Reusing the same output description for our back buffer as SceneColor
FRDGTextureDesc OutputDesc = SceneColor.Texture->Desc;
// this is flipped because the Output tensor has the vertical dimension first
// while unreal has the horizontal dimension first
OutputDesc.Extent = {OutputTensorDimensions[1], OutputTensorDimensions[0]};
OutputDesc.Flags |= TexCreate_RenderTargetable | TexCreate_UAV;
FLinearColor ClearColor(0., 0., 0., 0.);
OutputDesc.ClearValue = FClearValueBinding(ClearColor);
FRDGTexture* StyleTransferRenderTargetTexture = GraphBuilder.CreateTexture(
OutputDesc, TEXT("StyleTransferRenderTargetTexture"));
TSharedPtr<FScreenPassRenderTarget> StyleTransferOutputTarget = MakeShared<FScreenPassRenderTarget>(StyleTransferRenderTargetTexture, SceneColor.ViewRect, TSharedPtr<FScreenPassRenderTarget> StyleTransferOutputTarget = MakeShared<FScreenPassRenderTarget>(StyleTransferRenderTargetTexture, SceneColor.ViewRect,
ERenderTargetLoadAction::EClear); ERenderTargetLoadAction::EClear);
StyleTransferNetwork->Run(GraphBuilder, InferenceContext);
FRDGBufferRef StyleTransferOutputBuffer = GraphBuilder.RegisterExternalBuffer(StyleTransferContentInputTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(StyleTransferOutputBuffer, EPixelFormat::PF_FloatRGB);
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(StyleTransferRenderTargetTexture);
FIntVector OutputTensorToSceneColorGroupCount = FComputeShaderUtils::GetGroupCount(
{OutputTensorDimensions.X, OutputTensorDimensions.Y, 1},
FOutputTensorToSceneColorCS::ThreadGroupSize
);
TShaderMapRef<FOutputTensorToSceneColorCS> OutputTensorToSceneColorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("OutputTensorToSceneColor"),
OutputTensorToSceneColorParameters,
ERDGPassFlags::Compute,
[OutputTensorToSceneColorCS, OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, OutputTensorToSceneColorCS,
*OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount);
}
);
TSharedPtr<FScreenPassRenderTarget> BackBufferRenderTarget; TSharedPtr<FScreenPassRenderTarget> BackBufferRenderTarget;
// If the override output is provided it means that this is the last pass in post processing. // If the override output is provided it means that this is the last pass in post processing.
if (InOutInputs.OverrideOutput.IsValid()) if (InOutInputs.OverrideOutput.IsValid())
{ {
BackBufferRenderTarget = MakeShared<FScreenPassRenderTarget>(InOutInputs.OverrideOutput); BackBufferRenderTarget = MakeShared<FScreenPassRenderTarget>(InOutInputs.OverrideOutput);
// @todo: do not use copy. Resample the styled 1920x960 texture to the fullscreen texture by drawing into the texture
AddCopyTexturePass(GraphBuilder, StyleTransferRenderTargetTexture, BackBufferRenderTarget->Texture); AddRescalingTextureCopy(GraphBuilder, *StyleTransferOutputTarget->Texture, *BackBufferRenderTarget);
} }
else else
{ {

View File

@ -5,12 +5,116 @@
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "RenderGraphUtils.h" #include "RenderGraphUtils.h"
#include "StyleTransferModule.h"
#include "StyleTransferSceneViewExtension.h" #include "StyleTransferSceneViewExtension.h"
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
TEXT("r.StyleTransfer.Enabled"),
false,
TEXT("Set to true to enable style transfer")
);
void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection) void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
{ {
Super::Initialize(Collection); Super::Initialize(Collection);
CVarStyleTransferEnabled->OnChangedDelegate().AddUObject(this, &UStyleTransferSubsystem::HandleConsoleVariableChanged);
}
void UStyleTransferSubsystem::Deinitialize()
{
StopStylizingViewport();
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
if (!StylePredictionNetwork->IsLoaded() || !StyleTransferNetwork->IsLoaded())
{
UE_LOG(LogStyleTransfer, Error, TEXT("Not all networks were loaded, can not stylize viewport."));
return;
}
if (!StyleTransferSceneViewExtension)
{
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));
StyleTransferInferenceContext = MakeShared<int32>(StyleTransferNetwork->CreateInferenceContext());
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
FlushRenderingCommands();
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
}
StyleTransferSceneViewExtension->SetEnabled(true);
}
void UStyleTransferSubsystem::StopStylizingViewport()
{
FlushRenderingCommands();
StyleTransferSceneViewExtension.Reset();
if (StylePredictionInferenceContext != INDEX_NONE)
{
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
}
if (StyleTransferInferenceContext && *StyleTransferInferenceContext != INDEX_NONE)
{
StyleTransferNetwork->DestroyInferenceContext(*StyleTransferInferenceContext);
*StyleTransferInferenceContext = INDEX_NONE;
StyleTransferInferenceContext.Reset();
}
}
void UStyleTransferSubsystem::UpdateStyle(const FNeuralTensor& StyleImage)
{
checkf(StyleTransferSceneViewExtension.IsValid(), TEXT("Can not update style while not stylizing"));
checkf(StyleTransferInferenceContext.IsValid(), TEXT("Can not update style without inference context"));
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
GraphBuilder.Execute();
});
FlushRenderingCommands();
}
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
{
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());
StopStylizingViewport();
if (CVarStyleTransferEnabled->GetBool())
{
if(!(StyleTransferNetwork || StylePredictionNetwork))
{
LoadNetworks();
}
StartStylizingViewport(GetGameInstance()->GetGameViewportClient());
}
}
void UStyleTransferSubsystem::LoadNetworks()
{
StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer")); StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer"));
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor")); StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor"));
@ -29,7 +133,7 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
} }
else else
{ {
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded")); UE_LOG(LogStyleTransfer, Error, TEXT("StyleTransferNetwork could not be loaded"));
} }
@ -39,61 +143,6 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
} }
else else
{ {
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded.")); UE_LOG(LogStyleTransfer, Error, TEXT("StylePredictionNetwork could not be loaded."));
} }
} }
void UStyleTransferSubsystem::Deinitialize()
{
StyleTransferSceneViewExtension.Reset();
if(StylePredictionInferenceContext != INDEX_NONE)
{
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
}
if(StyleTransferInferenceContext != INDEX_NONE)
{
StyleTransferNetwork->DestroyInferenceContext(StyleTransferInferenceContext);
}
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));
StyleTransferInferenceContext = StyleTransferNetwork->CreateInferenceContext();
checkf(StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext);
}
void UStyleTransferSubsystem::UpdateStyle(const FNeuralTensor& StyleImage)
{
checkf(StyleTransferSceneViewExtension.IsValid(), TEXT("Can not update style while not stylizing"));
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
GraphBuilder.Execute();
});
FlushRenderingCommands();
}

View File

@ -1,6 +1,8 @@
#pragma once #pragma once
#include "SceneViewExtension.h" #include "SceneViewExtension.h"
struct FNeuralTensor;
struct FScreenPassRenderTarget;
class UNeuralNetwork; class UNeuralNetwork;
class FStyleTransferSceneViewExtension : public FSceneViewExtensionBase class FStyleTransferSceneViewExtension : public FSceneViewExtensionBase
@ -9,21 +11,40 @@ public:
using Ptr = TSharedPtr<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>; using Ptr = TSharedPtr<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>;
using Ref = TSharedRef<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>; using Ref = TSharedRef<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>;
FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, int32 InInferenceContext); FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, TSharedRef<int32> InInferenceContext);
// - ISceneViewExtension // - ISceneViewExtension
virtual void SubscribeToPostProcessingPass(EPostProcessingPass Pass, FAfterPassCallbackDelegateArray& InOutPassCallbacks, bool bIsPassEnabled) override; virtual void SubscribeToPostProcessingPass(EPostProcessingPass Pass, FAfterPassCallbackDelegateArray& InOutPassCallbacks, bool bIsPassEnabled) override;
FScreenPassTexture PostProcessPassAfterTonemap_RenderThread(FRDGBuilder& GraphBuilder, const FSceneView& View, FScreenPassTexture PostProcessPassAfterTonemap_RenderThread(FRDGBuilder& GraphBuilder, const FSceneView& View,
const FPostProcessMaterialInputs& InOutInputs); const FPostProcessMaterialInputs& InOutInputs);
virtual void SetupViewFamily(FSceneViewFamily& InViewFamily) override; virtual void SetupViewFamily(FSceneViewFamily& InViewFamily) override;
virtual void SetupView(FSceneViewFamily& InViewFamily, FSceneView& InView) override {};
virtual void BeginRenderViewFamily(FSceneViewFamily& InViewFamily) override {}; virtual void SetupView(FSceneViewFamily& InViewFamily, FSceneView& InView) override
{
}
virtual void BeginRenderViewFamily(FSceneViewFamily& InViewFamily) override
{
}
virtual bool IsActiveThisFrame_Internal(const FSceneViewExtensionContext& Context) const override;
// -- // --
void SetEnabled(bool bInIsEnabled) { bIsEnabled = bInIsEnabled; }
bool IsEnabled() const { return bIsEnabled; }
private: private:
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */
TWeakObjectPtr<UNeuralNetwork> StyleTransferNetworkWeakPtr;
TObjectPtr<UNeuralNetwork> StyleTransferNetwork; TObjectPtr<UNeuralNetwork> StyleTransferNetwork;
FViewportClient* LinkedViewportClient; FViewportClient* LinkedViewportClient;
int32 InferenceContext = -1; TSharedRef<int32, ESPMode::ThreadSafe> InferenceContext = MakeShared<int32>(-1);
bool bIsEnabled = true;
void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor);
}; };

View File

@ -24,6 +24,7 @@ public:
// -- // --
void StartStylizingViewport(FViewportClient* ViewportClient); void StartStylizingViewport(FViewportClient* ViewportClient);
void StopStylizingViewport();
void UpdateStyle(const FNeuralTensor& StyleImage); void UpdateStyle(const FNeuralTensor& StyleImage);
@ -37,7 +38,12 @@ private:
TObjectPtr<UNeuralNetwork> StylePredictionNetwork; TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
int32 StylePredictionInferenceContext = INDEX_NONE; int32 StylePredictionInferenceContext = INDEX_NONE;
int32 StyleTransferInferenceContext = INDEX_NONE; TSharedPtr<int32, ESPMode::ThreadSafe> StyleTransferInferenceContext;
int32 StyleTransferStyleParamsInputIndex = INDEX_NONE; int32 StyleTransferStyleParamsInputIndex = INDEX_NONE;
void HandleConsoleVariableChanged(IConsoleVariable*);
void LoadNetworks();
}; };

View File

@ -20,6 +20,7 @@ class STYLETRANSFERSHADERS_API FOutputTensorToSceneColorCS : public FGlobalShade
BEGIN_SHADER_PARAMETER_STRUCT(FParameters, ) BEGIN_SHADER_PARAMETER_STRUCT(FParameters, )
SHADER_PARAMETER(uint32, TensorVolume) SHADER_PARAMETER(uint32, TensorVolume)
SHADER_PARAMETER(FIntPoint, TextureSize)
SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float3>, InputTensor) SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float3>, InputTensor)
SHADER_PARAMETER_RDG_TEXTURE_UAV(RWTexture2D, OutputTexture) SHADER_PARAMETER_RDG_TEXTURE_UAV(RWTexture2D, OutputTexture)
END_SHADER_PARAMETER_STRUCT() END_SHADER_PARAMETER_STRUCT()