Added runtime style prediction
This commit is contained in:
parent
20144b8b59
commit
b580ba6316
BIN
Plugins/StyleTransfer/Content/T_StyleImage.uasset (Stored with Git LFS)
BIN
Plugins/StyleTransfer/Content/T_StyleImage.uasset (Stored with Git LFS)
Binary file not shown.
|
@ -113,6 +113,7 @@ void FStyleTransferSceneViewExtension::SubscribeToPostProcessingPass(EPostProces
|
||||||
|
|
||||||
void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily)
|
void FStyleTransferSceneViewExtension::PreRenderViewFamily_RenderThread(FRDGBuilder& GraphBuilder, FSceneViewFamily& InViewFamily)
|
||||||
{
|
{
|
||||||
|
return;
|
||||||
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
|
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
|
||||||
if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
|
if(!IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
|
||||||
return;
|
return;
|
||||||
|
@ -181,19 +182,19 @@ FRDGTexture* FStyleTransferSceneViewExtension::TensorToTexture(FRDGBuilder& Grap
|
||||||
return OutputTexture;
|
return OutputTexture;
|
||||||
}
|
}
|
||||||
|
|
||||||
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor)
|
void FStyleTransferSceneViewExtension::TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor)
|
||||||
{
|
{
|
||||||
const FIntVector InputTensorDimensions = {
|
const FIntVector InputTensorDimensions = {
|
||||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
|
CastNarrowingSafe<int32>(DestinationTensor.GetSize(1)),
|
||||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
|
CastNarrowingSafe<int32>(DestinationTensor.GetSize(2)),
|
||||||
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
|
CastNarrowingSafe<int32>(DestinationTensor.GetSize(3)),
|
||||||
};
|
};
|
||||||
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture.Texture->Desc.Extent;
|
const FIntPoint SceneColorRenderTargetDimensions = SourceTexture->Desc.Extent;
|
||||||
|
|
||||||
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
|
FRDGBufferRef StyleTransferContentInputBuffer = GraphBuilder.RegisterExternalBuffer(DestinationTensor.GetPooledBuffer());
|
||||||
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
|
auto SceneColorToInputTensorParameters = GraphBuilder.AllocParameters<FSceneColorToInputTensorCS::FParameters>();
|
||||||
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
|
SceneColorToInputTensorParameters->TensorVolume = CastNarrowingSafe<uint32>(DestinationTensor.Num());
|
||||||
SceneColorToInputTensorParameters->InputTexture = SourceTexture.Texture;
|
SceneColorToInputTensorParameters->InputTexture = SourceTexture;
|
||||||
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
|
SceneColorToInputTensorParameters->InputTextureSampler = TStaticSamplerState<SF_Bilinear>::GetRHI();
|
||||||
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
|
SceneColorToInputTensorParameters->OutputUAV = GraphBuilder.CreateUAV(StyleTransferContentInputBuffer);
|
||||||
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
|
SceneColorToInputTensorParameters->OutputDimensions = {InputTensorDimensions.X, InputTensorDimensions.Y};
|
||||||
|
@ -239,7 +240,7 @@ FScreenPassTexture FStyleTransferSceneViewExtension::PostProcessPassAfterTonemap
|
||||||
|
|
||||||
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
|
const FNeuralTensor& StyleTransferContentInputTensor = StyleTransferNetwork->GetInputTensorForContext(*InferenceContext, 0);
|
||||||
|
|
||||||
TextureToTensor(GraphBuilder, SceneColor, StyleTransferContentInputTensor);
|
TextureToTensor(GraphBuilder, SceneColor.Texture, StyleTransferContentInputTensor);
|
||||||
|
|
||||||
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
|
StyleTransferNetwork->Run(GraphBuilder, *InferenceContext);
|
||||||
|
|
||||||
|
|
|
@ -3,11 +3,13 @@
|
||||||
|
|
||||||
#include "StyleTransferSubsystem.h"
|
#include "StyleTransferSubsystem.h"
|
||||||
|
|
||||||
|
#include "IRenderCaptureProvider.h"
|
||||||
#include "NeuralNetwork.h"
|
#include "NeuralNetwork.h"
|
||||||
#include "RenderGraphUtils.h"
|
#include "RenderGraphUtils.h"
|
||||||
#include "ScreenPass.h"
|
#include "ScreenPass.h"
|
||||||
#include "StyleTransferModule.h"
|
#include "StyleTransferModule.h"
|
||||||
#include "StyleTransferSceneViewExtension.h"
|
#include "StyleTransferSceneViewExtension.h"
|
||||||
|
#include "TextureCompiler.h"
|
||||||
#include "Rendering/Texture2DResource.h"
|
#include "Rendering/Texture2DResource.h"
|
||||||
|
|
||||||
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
|
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(
|
||||||
|
@ -47,11 +49,14 @@ void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportCl
|
||||||
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
|
checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
|
||||||
|
|
||||||
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
|
UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
|
||||||
UpdateStyle(StyleTexture);
|
FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
|
||||||
|
//UpdateStyle(StyleTexture);
|
||||||
|
UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));
|
||||||
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
|
StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
|
||||||
|
|
||||||
}
|
}
|
||||||
StyleTransferSceneViewExtension->SetEnabled(true);
|
StyleTransferSceneViewExtension->SetEnabled(true);
|
||||||
|
ViewportClient->GetWorld()->GetWorldSettings()->SetPauserPlayerState(ViewportClient->GetWorld()->GetFirstPlayerController()->PlayerState);
|
||||||
}
|
}
|
||||||
|
|
||||||
void UStyleTransferSubsystem::StopStylizingViewport()
|
void UStyleTransferSubsystem::StopStylizingViewport()
|
||||||
|
@ -74,43 +79,75 @@ void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture)
|
||||||
{
|
{
|
||||||
checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
|
checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
|
||||||
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));
|
checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));
|
||||||
|
FlushRenderingCommands();
|
||||||
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList)
|
ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture](FRHICommandListImmediate& RHICommandList)
|
||||||
{
|
{
|
||||||
FRDGBuilder GraphBuilder(RHICommandList);
|
IRenderCaptureProvider* RenderCaptureProvider = nullptr;
|
||||||
|
const FName RenderCaptureProviderType = IRenderCaptureProvider::GetModularFeatureName();
|
||||||
|
if(IModularFeatures::Get().IsModularFeatureAvailable(RenderCaptureProviderType))
|
||||||
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
|
|
||||||
|
|
||||||
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
|
|
||||||
if(!StyleTextureResource->IsInitialized())
|
|
||||||
{
|
{
|
||||||
StyleTextureResource->UpdateRHI();
|
RenderCaptureProvider = &IModularFeatures::Get().GetModularFeature<IRenderCaptureProvider>(RenderCaptureProviderType);
|
||||||
|
RenderCaptureProvider->BeginCapture(&RHICommandList);
|
||||||
}
|
}
|
||||||
|
|
||||||
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
|
FRDGBuilder GraphBuilder(RHICommandList);
|
||||||
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, FScreenPassTexture(RDGStyleTexture), InputStyleImageTensor);
|
{
|
||||||
|
RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");
|
||||||
|
|
||||||
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
|
const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
|
||||||
|
FTextureResource* StyleTextureResource = StyleTexture->GetResource();
|
||||||
|
FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
|
||||||
|
FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);
|
||||||
|
|
||||||
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
|
StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);
|
||||||
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
|
|
||||||
|
|
||||||
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
|
const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
|
||||||
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
|
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
|
||||||
const uint64 NumBytes = OutputStyleParams.NumInBytes();
|
|
||||||
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
|
|
||||||
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
|
|
||||||
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
|
|
||||||
|
|
||||||
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
|
FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
|
||||||
|
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
|
||||||
|
const uint64 NumBytes = OutputStyleParams.NumInBytes();
|
||||||
|
check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
|
||||||
|
check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
|
||||||
|
check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());
|
||||||
|
|
||||||
|
AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
|
||||||
|
}
|
||||||
GraphBuilder.Execute();
|
GraphBuilder.Execute();
|
||||||
|
|
||||||
|
if(RenderCaptureProvider)
|
||||||
|
{
|
||||||
|
RenderCaptureProvider->EndCapture(&RHICommandList);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
FlushRenderingCommands();
|
FlushRenderingCommands();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
|
||||||
|
{
|
||||||
|
FArchive& FileReader = *IFileManager::Get().CreateFileReader(*StyleTensorDataPath);
|
||||||
|
TArray<float> StyleParams;
|
||||||
|
StyleParams.SetNumUninitialized(192);
|
||||||
|
|
||||||
|
FileReader << StyleParams;
|
||||||
|
|
||||||
|
ENQUEUE_RENDER_COMMAND(StyleParamsLoad)([this, StyleParams = MoveTemp(StyleParams)](FRHICommandListImmediate& RHICommandList)
|
||||||
|
{
|
||||||
|
const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);
|
||||||
|
|
||||||
|
FRDGBuilder GraphBuilder(RHICommandList);
|
||||||
|
{
|
||||||
|
RDG_EVENT_SCOPE(GraphBuilder, "StyleParamsLoad");
|
||||||
|
|
||||||
|
FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
|
||||||
|
GraphBuilder.QueueBufferUpload(InputStyleParamsBuffer, StyleParams.GetData(), StyleParams.Num() * StyleParams.GetTypeSize(), ERDGInitialDataFlags::NoCopy);
|
||||||
|
}
|
||||||
|
GraphBuilder.Execute();
|
||||||
|
});
|
||||||
|
FlushRenderingCommands();
|
||||||
|
}
|
||||||
|
|
||||||
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
|
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
|
||||||
{
|
{
|
||||||
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());
|
check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());
|
||||||
|
|
|
@ -38,7 +38,7 @@ public:
|
||||||
|
|
||||||
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
|
static void AddRescalingTextureCopy(FRDGBuilder& GraphBuilder, FRDGTexture& RDGSourceTexture, FScreenPassRenderTarget& DestinationRenderTarget);
|
||||||
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
|
static FRDGTexture* TensorToTexture(FRDGBuilder& GraphBuilder, const FRDGTextureDesc& BaseDestinationDesc, const FNeuralTensor& SourceTensor);
|
||||||
static void TextureToTensor(FRDGBuilder& GraphBuilder, const FScreenPassTexture& SourceTexture, const FNeuralTensor& DestinationTensor);
|
static void TextureToTensor(FRDGBuilder& GraphBuilder, FRDGTextureRef SourceTexture, const FNeuralTensor& DestinationTensor);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */
|
/** The actual Network pointer is not tracked so we need a WeakPtr too so we can check its validity on the game thread. */
|
||||||
|
|
|
@ -26,6 +26,7 @@ public:
|
||||||
void StopStylizingViewport();
|
void StopStylizingViewport();
|
||||||
|
|
||||||
void UpdateStyle(UTexture2D* StyleTexture);
|
void UpdateStyle(UTexture2D* StyleTexture);
|
||||||
|
void UpdateStyle(FString StyleTensorDataPath);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension;
|
FStyleTransferSceneViewExtension::Ptr StyleTransferSceneViewExtension;
|
||||||
|
|
Loading…
Reference in New Issue