Fixed the compute shaders for tensor conversion

Should in theory work now but it does not do anything
This commit is contained in:
Manuel Wagner 2022-08-26 17:33:40 +02:00
parent 7ec56a0475
commit e717b41489
10 changed files with 300 additions and 170 deletions

Binary file not shown.

Binary file not shown.

View File

@ -3,23 +3,21 @@
#include "/Engine/Public/Platform.ush"
// this assumes that the OutputTexture has
// the exact same dimensions as InputTensor!
uint TensorVolume;
uint2 TextureSize;
RWTexture2D<float4> OutputTexture;
Buffer<float> InputTensor;
// DispatchThreadID corresponds to InputTensor shape dimensions not texture XY -> DispatchThreadID.X = Texture.Y
[numthreads(THREADGROUP_SIZE_X, THREADGROUP_SIZE_Y, THREADGROUP_SIZE_Z)]
void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{
uint TensorVolume;
InputTensor.GetDimensions(TensorVolume);
uint2 TextureSize = 0;
OutputTexture.GetDimensions(TextureSize.x, TextureSize.y);
// note that the input tensor has shape (1, Y, X, C)
// which is why we need to flip the indexing
const uint PixelIndex = DispatchThreadID.x * TextureSize.y + DispatchThreadID.x;
const uint GlobalIndex = PixelIndex * 3;
const uint TensorPixelNumber = DispatchThreadID.x * TextureSize.x + DispatchThreadID.y;
const uint GlobalIndex = TensorPixelNumber * 3;
if (GlobalIndex >= TensorVolume)
{
@ -27,10 +25,11 @@ void OutputTensorToSceneColorCS(in const uint3 DispatchThreadID : SV_DispatchThr
}
const uint2 TextureCoords = uint2(DispatchThreadID.y, DispatchThreadID.x);
OutputTexture[TextureCoords] = float4(
const float4 RGBAColor = float4(
InputTensor[GlobalIndex + 0],
InputTensor[GlobalIndex + 1],
InputTensor[GlobalIndex + 2],
1.0f
0.0f
);
OutputTexture[TextureCoords] = RGBAColor;
}

View File

@ -12,16 +12,16 @@ float2 HalfPixelUV;
void SceneColorToInputTensorCS(in const uint3 DispatchThreadID : SV_DispatchThreadID)
{
const uint2 OutputUAVTexelCoordinate = DispatchThreadID.xy;
if(any(OutputUAVTexelCoordinate > OutputDimensions))
if(any(OutputUAVTexelCoordinate >= OutputDimensions))
{
return;
}
const uint GlobalIndex = (OutputUAVTexelCoordinate.x * OutputDimensions.x + OutputUAVTexelCoordinate.y) * 3;
const uint GlobalIndex = ((OutputUAVTexelCoordinate.x * OutputDimensions.y + OutputUAVTexelCoordinate.y)*3);
// note that the input tensor has shape (1, Y, X, C)
// note that the OutputUAV has shape (1, Y, X, C)
// which is why we need to flip the indexing
const float2 UV = OutputUAVTexelCoordinate.yx / OutputDimensions.yx + HalfPixelUV;
const float2 UV = float2(OutputUAVTexelCoordinate.yx) / float2(OutputDimensions.yx) + HalfPixelUV;
const float4 TextureValue = InputTexture.SampleLevel(InputTextureSampler, UV, 0);

View File

@ -18,7 +18,9 @@
#include "Containers/DynamicRHIResourceArray.h"
#include "PostProcess/PostProcessMaterial.h"
#include "OutputTensorToSceneColorCS.h"
#include "PixelShaderUtils.h"
#include "SceneColorToInputTensorCS.h"
#include "StyleTransferModule.h"
template <class OutType, class InType>
OutType CastNarrowingSafe(InType InValue)
@ -35,8 +37,9 @@ OutType CastNarrowingSafe(InType InValue)
}
FStyleTransferSceneViewExtension::FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, int32 InInferenceContext)
FStyleTransferSceneViewExtension::FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, TSharedRef<int32> InInferenceContext)
: FSceneViewExtensionBase(AutoRegister)
, StyleTransferNetworkWeakPtr(InStyleTransferNetwork)
, StyleTransferNetwork(InStyleTransferNetwork)
, LinkedViewportClient(AssociatedViewportClient)
, InferenceContext(InInferenceContext)
@ -48,9 +51,54 @@ void FStyleTransferSceneViewExtension::SetupViewFamily(FSceneViewFamily& InViewF
{
}
void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProcessingPass PassId,
FAfterPassCallbackDelegateArray&
InOutPassCallbacks, bool bIsPassEnabled)
bool FStyleTransferSceneViewExtension::IsActiveThisFrame_Internal(const FSceneViewExtensionContext& Context) const
{
check(IsInGameThread());
return bIsEnabled && *InferenceContext != -1 && StyleTransferNetworkWeakPtr.IsValid();
}
void FStyleTransferSceneViewExtension::AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget)
{
FGlobalShaderMap* ShaderMap = GetGlobalShaderMap(GMaxRHIFeatureLevel);
TShaderMapRef<FScreenPassVS> VertexShader(ShaderMap);
TShaderMapRef<FCopyRectPS> PixelShader(ShaderMap);
FCopyRectPS::FParameters* PixelShaderParameters = GraphBuilder.AllocParameters<FCopyRectPS::FParameters>();
PixelShaderParameters->InputTexture = &RDGSourceTexture;
PixelShaderParameters->InputSampler = TStaticSamplerState<SF_Bilinear, AM_Clamp, AM_Clamp, AM_Clamp>::GetRHI();
PixelShaderParameters->RenderTargets[0] = DestinationRenderTarget.GetRenderTargetBinding();
ClearUnusedGraphResources(PixelShader, PixelShaderParameters);
FRHIBlendState* BlendState = FScreenPassPipelineState::FDefaultBlendState::GetRHI();
FRHIDepthStencilState* DepthStencilState = FScreenPassPipelineState::FDefaultDepthStencilState::GetRHI();
const FScreenPassPipelineState PipelineState(VertexShader, PixelShader, BlendState, DepthStencilState);
GraphBuilder.AddPass(
RDG_EVENT_NAME("RescalingTextureCopy"),
PixelShaderParameters,
ERDGPassFlags::Raster,
[PipelineState, Extent = DestinationRenderTarget.Texture->Desc.Extent, PixelShader, PixelShaderParameters](FRHICommandList& RHICmdList)
{
PipelineState.Validate();
RHICmdList.SetViewport(0.0f, 0.0f, 0.0f, Extent.X, Extent.Y, 1.0f);
SetScreenPassPipelineState(RHICmdList, PipelineState);
SetShaderParameters(RHICmdList, PixelShader, PixelShader.GetPixelShader(), *PixelShaderParameters);
DrawRectangle(
RHICmdList,
0, 0, Extent.X, Extent.Y,
0, 0, Extent.X, Extent.Y,
Extent,
Extent,
PipelineState.VertexShader,
EDRF_UseTriangleOptimization);
});
}
void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProcessingPass PassId, FAfterPassCallbackDelegateArray& InOutPassCallbacks, bool bIsPassEnabled)
{
if (PassId == EPostProcessingPass::Tonemap)
{
@ -60,8 +108,88 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
}
}
FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap_RenderThread(
FRDGBuilder& GraphBuilder, const FSceneView& View, const FPostProcessMaterialInputs& InOutInputs)
FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor)
{
FIntVector SourceTensorDimensions = {
CastNarrowingSafe<int32>(SourceTensor.GetSize(1)),
CastNarrowingSafe<int32>(SourceTensor.GetSize(2)),
CastNarrowingSafe<int32>(SourceTensor.GetSize(3)),
};
// Reusing the same output description for our back buffer as SceneColor
FRDGTextureDesc DestinationDesc = BaseDestinationDesc;
// this is flipped because the Output tensor has the vertical dimension first
// while unreal has the horizontal dimension first
DestinationDesc.Extent = {SourceTensorDimensions[1], SourceTensorDimensions[0]};
DestinationDesc.Flags |= TexCreate_RenderTargetable | TexCreate_UAV;
FLinearColor ClearColor(0., 0., 0., 0.);
DestinationDesc.ClearValue = FClearValueBinding(ClearColor);
FRDGTexture* OutputTexture = GraphBuilder.CreateTexture(
DestinationDesc, TEXT("OutputTexture"));
FRDGBufferRef SourceTensorBuffer = GraphBuilder.RegisterExternalBuffer(SourceTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(SourceTensorBuffer, EPixelFormat::PF_R32_FLOAT);
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(OutputTexture);
OutputTensorToSceneColorParameters->TensorVolume = SourceTensor.Num();
OutputTensorToSceneColorParameters->TextureSize = DestinationDesc.Extent;
FIntVector OutputTensorToSceneColorGroupCount = FComputeShaderUtils::GetGroupCount(
{SourceTensorDimensions.X, SourceTensorDimensions.Y, 1},
FOutputTensorToSceneColorCS::ThreadGroupSize
);
TShaderMapRef<FOutputTensorToSceneColorCS> OutputTensorToSceneColorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("TensorToTexture"),
OutputTensorToSceneColorParameters,
ERDGPassFlags::Compute,
[OutputTensorToSceneColorCS, OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, OutputTensorToSceneColorCS,
*OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount);
}
);
return OutputTexture;
}
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor)
{
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
};
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent;
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
SceneColorToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / SceneColorRenderTargetDimensions.X, 0.5 / SceneColorRenderTargetDimensions.Y);
FIntVector SceneColorToInputTensorGroupCount = FComputeShaderUtils::GetGroupCount(
{InputTensorDimensions.X, InputTensorDimensions.Y, 1},
FSceneColorToInputTensorCS::ThreadGroupSize
);
TShaderMapRef<FSceneColorToInputTensorCS> SceneColorToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("TextureToTensor"),
SceneColorToInputTensorParameters,
ERDGPassFlags::Compute,
[SceneColorToInputTensorCS, SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, SceneColorToInputTensorCS,
*SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount);
}
);
}
FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap_RenderThread(FRDGBuilder& GraphBuilder, const FSceneView& View, const FPostProcessMaterialInputs& InOutInputs)
{
const FSceneViewFamily& ViewFamily = *View.Family;
@ -79,103 +207,29 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
RDG_EVENT_SCOPE(GraphBuilder, "StyleTransfer");
//Get input and output viewports. Backbuffer could be targeting a different region than input viewport
const FScreenPassTextureViewport SceneColorViewport(SceneColor);
FScreenPassRenderTarget SceneColorRenderTarget(SceneColor, ERenderTargetLoadAction::ELoad);
checkSlow(View.bIsViewInfo);
const FViewInfo& ViewInfo = static_cast<const FViewInfo&>(View);
/*AddDrawScreenPass(GraphBuilder, RDG_EVENT_NAME("ProcessOCIOColorSpaceXfrm"), ViewInfo, BackBufferViewport,
SceneColorViewport, OCIOPixelShader, Parameters);*/
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(InferenceContext, 0);
TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor);
const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(1)),
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(2)),
CastNarrowingSafe<int32>(StyleTransferContentInputTensor.GetSize(3)),
};
const FIntPoint SceneColorRenderTargetDimensions = SceneColorRenderTarget.Texture->Desc.Extent;
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(StyleTransferContentInputTensor.GetPooledBuffer());
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(StyleTransferContentInputTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SceneColorRenderTarget.Texture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
SceneColorToInputTensorParameters->HalfPixelUV = FVector2f(0.5f / SceneColorRenderTargetDimensions.X, 0.5 / SceneColorRenderTargetDimensions.Y);
FIntVector SceneColorToInputTensorGroupCount = FComputeShaderUtils::GetGroupCount(
{InputTensorDimensions.X, InputTensorDimensions.Y, 1},
FSceneColorToInputTensorCS::ThreadGroupSize
);
const FNeuralTensor& StyleTransferContentOutputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
FRDGTexture* StyleTransferRenderTargetTexture = TensorToTexture(GraphBuilder, SceneColor.Texture->Desc, StyleTransferContentOutputTensor);
TShaderMapRef<FSceneColorToInputTensorCS> SceneColorToInputTensorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("SceneColorToInputTensor"),
SceneColorToInputTensorParameters,
ERDGPassFlags::Compute,
[SceneColorToInputTensorCS, SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, SceneColorToInputTensorCS,
*SceneColorToInputTensorParameters, SceneColorToInputTensorGroupCount);
}
);
const FNeuralTensor& StyleTransferOutputTensor = StyleTransferNetwork->GetOutputTensorForContext(InferenceContext, 0);
FIntVector OutputTensorDimensions = {
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(1)),
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(2)),
CastNarrowingSafe<int32>(StyleTransferOutputTensor.GetSize(3)),
};
// Reusing the same output description for our back buffer as SceneColor
FRDGTextureDesc OutputDesc = SceneColor.Texture->Desc;
// this is flipped because the Output tensor has the vertical dimension first
// while unreal has the horizontal dimension first
OutputDesc.Extent = {OutputTensorDimensions[1], OutputTensorDimensions[0]};
OutputDesc.Flags |= TexCreate_RenderTargetable | TexCreate_UAV;
FLinearColor ClearColor(0., 0., 0., 0.);
OutputDesc.ClearValue = FClearValueBinding(ClearColor);
FRDGTexture* StyleTransferRenderTargetTexture = GraphBuilder.CreateTexture(
OutputDesc, TEXT("StyleTransferRenderTargetTexture"));
TSharedPtr<FScreenPassRenderTarget> StyleTransferOutputTarget = MakeShared<FScreenPassRenderTarget>(StyleTransferRenderTargetTexture, SceneColor.ViewRect,
ERenderTargetLoadAction::EClear);
StyleTransferNetwork->Run(GraphBuilder, InferenceContext);
FRDGBufferRef StyleTransferOutputBuffer = GraphBuilder.RegisterExternalBuffer(StyleTransferContentInputTensor.GetPooledBuffer());
auto OutputTensorToSceneColorParameters = GraphBuilder.AllocParameters<FOutputTensorToSceneColorCS::FParameters>();
OutputTensorToSceneColorParameters->InputTensor = GraphBuilder.CreateSRV(StyleTransferOutputBuffer, EPixelFormat::PF_FloatRGB);
OutputTensorToSceneColorParameters->OutputTexture = GraphBuilder.CreateUAV(StyleTransferRenderTargetTexture);
FIntVector OutputTensorToSceneColorGroupCount = FComputeShaderUtils::GetGroupCount(
{OutputTensorDimensions.X, OutputTensorDimensions.Y, 1},
FOutputTensorToSceneColorCS::ThreadGroupSize
);
TShaderMapRef<FOutputTensorToSceneColorCS> OutputTensorToSceneColorCS(GetGlobalShaderMap(GMaxRHIFeatureLevel));
GraphBuilder.AddPass(
RDG_EVENT_NAME("OutputTensorToSceneColor"),
OutputTensorToSceneColorParameters,
ERDGPassFlags::Compute,
[OutputTensorToSceneColorCS, OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount](FRHICommandList& RHICommandList)
{
FComputeShaderUtils::Dispatch(RHICommandList, OutputTensorToSceneColorCS,
*OutputTensorToSceneColorParameters, OutputTensorToSceneColorGroupCount);
}
);
TSharedPtr<FScreenPassRenderTarget> BackBufferRenderTarget;
// If the override output is provided it means that this is the last pass in post processing.
if (InOutInputs.OverrideOutput.IsValid())
{
BackBufferRenderTarget = MakeShared<FScreenPassRenderTarget>(InOutInputs.OverrideOutput);
// @todo: do not use copy. Resample the styled 1920x960 texture to the fullscreen texture by drawing into the texture
AddCopyTexturePass(GraphBuilder, StyleTransferRenderTargetTexture, BackBufferRenderTarget->Texture);
AddRescalingTextureCopy(GraphBuilder, *StyleTransferOutputTarget->Texture, *BackBufferRenderTarget);
}
else
{

View File

@ -5,12 +5,116 @@
#include "NeuralNetwork.h"
#include "RenderGraphUtils.h"
#include "StyleTransferModule.h"
#include "StyleTransferSceneViewExtension.h"
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
TEXT("r.StyleTransfer.Enabled"),
false,
TEXT("Set to true to enable style transfer")
);
void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
{
Super::Initialize(Collection);
CVarStyleTransferEnabled->OnChangedDelegate().AddUObject(this, &UStyleTransferSubsystem::HandleConsoleVariableChanged);
}
void UStyleTransferSubsystem::Deinitialize()
{
StopStylizingViewport();
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
if (!StylePredictionNetwork->IsLoaded() || !StyleTransferNetwork->IsLoaded())
{
UE_LOG(LogStyleTransfer, Error, TEXT("Not all networks were loaded, can not stylize viewport."));
return;
}
if (!StyleTransferSceneViewExtension)
{
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));
StyleTransferInferenceContext = MakeShared<int32>(StyleTransferNetwork->CreateInferenceContext());
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
FlushRenderingCommands();
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
}
StyleTransferSceneViewExtension->SetEnabled(true);
}
void UStyleTransferSubsystem::StopStylizingViewport()
{
FlushRenderingCommands();
StyleTransferSceneViewExtension.Reset();
if (StylePredictionInferenceContext != INDEX_NONE)
{
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
}
if (StyleTransferInferenceContext && *StyleTransferInferenceContext != INDEX_NONE)
{
StyleTransferNetwork->DestroyInferenceContext(*StyleTransferInferenceContext);
*StyleTransferInferenceContext = INDEX_NONE;
StyleTransferInferenceContext.Reset();
}
}
void UStyleTransferSubsystem::UpdateStyle(const FNeuralTensor& StyleImage)
{
checkf(StyleTransferSceneViewExtension.IsValid(), TEXT("Can not update style while not stylizing"));
checkf(StyleTransferInferenceContext.IsValid(), TEXT("Can not update style without inference context"));
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
GraphBuilder.Execute();
});
FlushRenderingCommands();
}
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
{
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());
StopStylizingViewport();
if (CVarStyleTransferEnabled->GetBool())
{
if(!(StyleTransferNetwork || StylePredictionNetwork))
{
LoadNetworks();
}
StartStylizingViewport(GetGameInstance()->GetGameViewportClient());
}
}
void UStyleTransferSubsystem::LoadNetworks()
{
StyleTransferNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StyleTransfer.NN_StyleTransfer"));
StylePredictionNetwork = LoadObject<UNeuralNetwork>(this, TEXT("/StyleTransfer/NN_StylePredictor.NN_StylePredictor"));
@ -29,7 +133,7 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
}
else
{
UE_LOG(LogTemp, Warning, TEXT("StyleTransferNetwork could not be loaded"));
UE_LOG(LogStyleTransfer, Error, TEXT("StyleTransferNetwork could not be loaded"));
}
@ -39,61 +143,6 @@ void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
}
else
{
UE_LOG(LogTemp, Warning, TEXT("StylePredictionNetwork could not be loaded."));
UE_LOG(LogStyleTransfer, Error, TEXT("StylePredictionNetwork could not be loaded."));
}
}
void UStyleTransferSubsystem::Deinitialize()
{
StyleTransferSceneViewExtension.Reset();
if(StylePredictionInferenceContext != INDEX_NONE)
{
StylePredictionNetwork->DestroyInferenceContext(StylePredictionInferenceContext);
}
if(StyleTransferInferenceContext != INDEX_NONE)
{
StyleTransferNetwork->DestroyInferenceContext(StyleTransferInferenceContext);
}
Super::Deinitialize();
}
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
StylePredictionInferenceContext = StylePredictionNetwork->CreateInferenceContext();
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));
StyleTransferInferenceContext = StyleTransferNetwork->CreateInferenceContext();
checkf(StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext);
}
void UStyleTransferSubsystem::UpdateStyle(const FNeuralTensor& StyleImage)
{
checkf(StyleTransferSceneViewExtension.IsValid(), TEXT("Can not update style while not stylizing"));
StylePredictionNetwork->SetInputFromArrayCopy(StyleImage.GetArrayCopy<float>());
ENQUEUE_RENDER_COMMAND(StylePrediction)([this](FRHICommandListImmediate& RHICommandList)
{
FRDGBuilder GraphBuilder(RHICommandList);
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
GraphBuilder.Execute();
});
FlushRenderingCommands();
}

View File

@ -1,6 +1,8 @@
#pragma once
#include "SceneViewExtension.h"
struct FNeuralTensor;
struct FScreenPassRenderTarget;
class UNeuralNetwork;
class FStyleTransferSceneViewExtension : public FSceneViewExtensionBase
@ -9,21 +11,40 @@ public:
using Ptr = TSharedPtr<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>;
using Ref = TSharedRef<FStyleTransferSceneViewExtension, ESPMode::ThreadSafe>;
FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, int32 InInferenceContext);
FStyleTransferSceneViewExtension(const FAutoRegister& AutoRegister, FViewportClient* AssociatedViewportClient, UNeuralNetwork* InStyleTransferNetwork, TSharedRef<int32> InInferenceContext);
// - ISceneViewExtension
virtual void SubscribeToPostProcessingPass(EPostProcessingPass Pass, FAfterPassCallbackDelegateArray& InOutPassCallbacks, bool bIsPassEnabled) override;
FScreenPassTexture PostProcessPassAfterTonemap_RenderThread(FRDGBuilder& GraphBuilder, const FSceneView& View,
const FPostProcessMaterialInputs& InOutInputs);
virtual void SetupViewFamily(FSceneViewFamily& InViewFamily) override;
virtual void SetupView(FSceneViewFamily& InViewFamily, FSceneView& InView) override {};
virtual void BeginRenderViewFamily(FSceneViewFamily& InViewFamily) override {};
virtual void SetupView(FSceneViewFamily& InViewFamily, FSceneView& InView) override
{
}
virtual void BeginRenderViewFamily(FSceneViewFamily& InViewFamily) override
{
}
virtual bool IsActiveThisFrame_Internal(const FSceneViewExtensionContext& Context) const override;
// --
void SetEnabled(bool bInIsEnabled) { bIsEnabled = bInIsEnabled; }
bool IsEnabled() const { return bIsEnabled; }
private:
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */
TWeakObjectPtr<UNeuralNetwork> StyleTransferNetworkWeakPtr;
TObjectPtr<UNeuralNetwork> StyleTransferNetwork;
FViewportClient* LinkedViewportClient;
int32 InferenceContext = -1;
TSharedRef<int32, ESPMode::ThreadSafe> InferenceContext = MakeShared<int32>(-1);
bool bIsEnabled = true;
void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor);
};

View File

@ -24,6 +24,7 @@ public:
// --
void StartStylizingViewport(FViewportClient* ViewportClient);
void StopStylizingViewport();
void UpdateStyle(const FNeuralTensor& StyleImage);
@ -37,7 +38,12 @@ private:
TObjectPtr<UNeuralNetwork> StylePredictionNetwork;
int32 StylePredictionInferenceContext = INDEX_NONE;
int32 StyleTransferInferenceContext = INDEX_NONE;
TSharedPtr<int32, ESPMode::ThreadSafe> StyleTransferInferenceContext;
int32 StyleTransferStyleParamsInputIndex = INDEX_NONE;
void HandleConsoleVariableChanged(IConsoleVariable*);
void LoadNetworks();
};

View File

@ -20,6 +20,7 @@ class STYLETRANSFERSHADERS_API FOutputTensorToSceneColorCS : public FGlobalShade
BEGIN_SHADER_PARAMETER_STRUCT(FParameters, )
SHADER_PARAMETER(uint32, TensorVolume)
SHADER_PARAMETER(FIntPoint, TextureSize)
SHADER_PARAMETER_RDG_BUFFER_SRV(Buffer<float3>, InputTensor)
SHADER_PARAMETER_RDG_TEXTURE_UAV(RWTexture2D, OutputTexture)
END_SHADER_PARAMETER_STRUCT()