Added runtime style prediction

This commit is contained in:
Manuel Wagner 2022-09-05 13:38:05 +02:00
parent 20144b8b59
commit b580ba6316
5 changed files with 68 additions and 29 deletions

Binary file not shown.

View File

@ -113,6 +113,7 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily) void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily)
{ {
return;
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName(); const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType)) if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
return; return;
@ -181,19 +182,19 @@ FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& Grap
return OutputTexture; return OutputTexture;
} }
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor) void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
{ {
const FIntVector InputTensorDimensions = { const FIntVector InputTensorDimensions = {
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)), CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)), CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)), CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
}; };
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent; const FIntPoint SceneColorRenderTargetDimensions = SourceTexture->Desc.Extent;
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer()); FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>(); auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num()); SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture; SceneColorToInputTensorParameters->InputTexture = SourceTexture;
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI(); SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer); SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y}; SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
@ -239,7 +240,7 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0); const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor); TextureToTensor(GraphBuilder, SceneColor.Texture, StyleTransferContentInputTensor);
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext); StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);

View File

@ -3,11 +3,13 @@
#include "StyleTransferSubsystem.h" #include "StyleTransferSubsystem.h"
#include "IRenderCaptureProvider.h"
#include "NeuralNetwork.h" #include "NeuralNetwork.h"
#include "RenderGraphUtils.h" #include "RenderGraphUtils.h"
#include "ScreenPass.h" #include "ScreenPass.h"
#include "StyleTransferModule.h" #include "StyleTransferModule.h"
#include "StyleTransferSceneViewExtension.h" #include "StyleTransferSceneViewExtension.h"
#include "TextureCompiler.h"
#include "Rendering/Texture2DResource.h" #include "Rendering/Texture2DResource.h"
TAutoConsoleVariable<bool> CVarStyleTransferEnabled( TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
@ -47,11 +49,14 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork")); checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'")); UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
UpdateStyle(StyleTexture); FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
//UpdateStyle(StyleTexture);
UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef()); StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
} }
StyleTransferSceneViewExtension->SetEnabled(true); StyleTransferSceneViewExtension->SetEnabled(true);
ViewportClient->GetWorld()->GetWorldSettings()->SetPauserPlayerState(ViewportClient->GetWorld()->GetFirstPlayerController()->PlayerState);
} }
void UStyleTransferSubsystem::StopStylizingViewport() void UStyleTransferSubsystem::StopStylizingViewport()
@ -74,43 +79,75 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
{ {
checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context")); checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context")); checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));
FlushRenderingCommands();
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList) ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList)
{ {
FRDGBuilder GraphBuilder(RHICommandList); IRenderCaptureProvider* RenderCaptureProvider = nullptr;
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
if(IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
if(!StyleTextureResource->IsInitialized())
{ {
StyleTextureResource->UpdateRHI(); RenderCaptureProvider = &IModularFeatures::Get().GetModularFeature<IRenderCaptureProvider>(RenderCaptureProviderType);
RenderCaptureProvider->BeginCapture(&RHICommandList);
} }
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture"))); FRDGBuilder GraphBuilder(RHICommandList);
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, FScreenPassTexture(RDGStyleTexture), InputStyleImageTensor); {
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext); const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0); StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer()); const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer()); const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer); FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
const uint64 NumBytes = OutputStyleParams.NumInBytes();
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
}
GraphBuilder.Execute(); GraphBuilder.Execute();
if(RenderCaptureProvider)
{
RenderCaptureProvider->EndCapture(&RHICommandList);
}
}); });
FlushRenderingCommands(); FlushRenderingCommands();
} }
void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
{
FArchive& FileReader = *IFileManager::Get().CreateFileReader(*StyleTensorDataPath);
TArray<float> StyleParams;
StyleParams.SetNumUninitialized(192);
FileReader << StyleParams;
ENQUEUE_RENDER_COMMAND(StyleParamsLoad)([this, StyleParams = MoveTemp(StyleParams)](FRHICommandListImmediate& RHICommandList)
{
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
FRDGBuilder GraphBuilder(RHICommandList);
{
RDG_EVENT_SCOPE(GraphBuilder, "StyleParamsLoad");
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
GraphBuilder.QueueBufferUpload(InputStyleParamsBuffer, StyleParams.GetData(), StyleParams.Num() * StyleParams.GetTypeSize(), ERDGInitialDataFlags::NoCopy);
}
GraphBuilder.Execute();
});
FlushRenderingCommands();
}
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable) void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
{ {
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable()); check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());

View File

@ -38,7 +38,7 @@ public:
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget); static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor); static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
static void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor); static void TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
private: private:
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */ /** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */

View File

@ -26,6 +26,7 @@ public:
void StopStylizingViewport(); void StopStylizingViewport();
void UpdateStyle(UTexture2D* StyleTexture); void UpdateStyle(UTexture2D* StyleTexture);
void UpdateStyle(FString StyleTensorDataPath);
private: private:
FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension; FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension;