// Fill out your copyright notice in the Description page of Project Settings.
|
|
|
|
|
|
#include "StyleTransferSubsystem.h"
|
|
|
|
#include "IRenderCaptureProvider.h"
|
|
#include "NeuralNetwork.h"
|
|
#include "RenderGraphUtils.h"
|
|
#include "ScreenPass.h"
|
|
#include "StyleTransferModule.h"
|
|
#include "StyleTransferSceneViewExtension.h"
|
|
#include "StyleTransferSettings.h"
|
|
#include "TextureCompiler.h"
|
|
#include "Rendering/Texture2DResource.h"
|
|
|
|
/**
 * Console variable "r.StyleTransfer.Enabled" (defaults to false).
 * UStyleTransferSubsystem subscribes to its change delegate and starts or stops
 * viewport stylization accordingly.
 */
TAutoConsoleVariable<bool> CVarStyleTransferEnabled(TEXT("r.StyleTransfer.Enabled"), false, TEXT("Set to true to enable style transfer"), ECVF_Default);
|
|
|
|
/**
 * Console variable "r.StyleTransfer.AutoCapturePrediction" (defaults to false).
 * Read on the render thread by ConditionalBeginRenderCapture() to decide whether
 * a render capture is started around the style passes.
 */
TAutoConsoleVariable<bool> CVarAutoCaptureStylePrediction(TEXT("r.StyleTransfer.AutoCapturePrediction"), false, TEXT("Set to true to enable style transfer auto capture for profiling in PIX etc."), ECVF_Default);
|
|
|
|
|
|
/**
 * Initializes the subsystem and subscribes to changes of r.StyleTransfer.Enabled.
 *
 * @param Collection Subsystem collection forwarded to the base implementation.
 */
void UStyleTransferSubsystem::Initialize(FSubsystemCollectionBase& Collection)
{
	Super::Initialize(Collection);

	// Every change of the cvar is routed to HandleConsoleVariableChanged.
	auto& EnabledChangedDelegate = CVarStyleTransferEnabled->OnChangedDelegate();
	EnabledChangedDelegate.AddUObject(this, &UStyleTransferSubsystem::HandleConsoleVariableChanged);
}
|
|
|
|
void UStyleTransferSubsystem::Deinitialize()
|
|
{
|
|
StopStylizingViewport();
|
|
|
|
Super::Deinitialize();
|
|
}
|
|
|
|
bool UStyleTransferSubsystem::Tick(float DeltaTime)
|
|
{
|
|
if (!GetWorld())
|
|
return true;
|
|
|
|
|
|
if (StylePredictionInferenceContexts.Num() > 1)
|
|
{
|
|
const UStyleTransferSettings* StyleTransferSettings = GetDefault<UStyleTransferSettings>();
|
|
const FRichCurve* InterpCurve = StyleTransferSettings->InterpolationCurve.GetRichCurveConst();
|
|
float MinTime, MaxTime;
|
|
InterpCurve->GetTimeRange(MinTime, MaxTime);
|
|
const double Time = MinTime + FMath::Fmod(GetWorld()->GetTimeSeconds(), static_cast<double>(MaxTime - MinTime));
|
|
const float Alpha = InterpCurve->Eval(Time);
|
|
UE_LOG(LogStyleTransfer, VeryVerbose, TEXT("Alpha is %0.4f"), Alpha);
|
|
InterpolateStyles(StylePredictionInferenceContexts[0], StylePredictionInferenceContexts[1], Alpha);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/**
 * Enables style transfer for the given viewport.
 *
 * On the first call (no scene view extension yet) this creates the style transfer
 * inference context, creates one style prediction context for each of the first two
 * configured style textures, runs style prediction for each of them and registers
 * the scene view extension. On every call the extension is enabled.
 *
 * @param ViewportClient Viewport to stylize; only needed while no extension exists yet.
 */
void UStyleTransferSubsystem::StartStylizingViewport(FViewportClient* ViewportClient)
{
	// The network pointers are null when loading failed or LoadNetworks() never ran,
	// so test them before asking whether they are loaded.
	if (!StylePredictionNetwork || !StyleTransferNetwork || !StylePredictionNetwork->IsLoaded() || !StyleTransferNetwork->IsLoaded())
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("Not all networks were loaded, can not stylize viewport."));
		return;
	}

	if (!StyleTransferSceneViewExtension)
	{
		if (!ViewportClient)
		{
			UE_LOG(LogStyleTransfer, Error, TEXT("No viewport client available, can not stylize viewport."));
			return;
		}

		const UStyleTransferSettings* StyleTransferSettings = GetDefault<UStyleTransferSettings>();

		if (!StyleTransferInferenceContext || *StyleTransferInferenceContext == INDEX_NONE)
		{
			StyleTransferInferenceContext = MakeShared<int32>(StyleTransferNetwork->CreateInferenceContext());
			checkf(*StyleTransferInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StyleTransferNetwork"));
		}

		// At most two styles are used; Tick() interpolates between the first two.
		const int32 NumStyles = FMath::Min(2, StyleTransferSettings->StyleTextures.Num());
		for (int32 i = 0; i < NumStyles; ++i)
		{
			// Load the texture before creating a context so a missing asset does not
			// leave behind a prediction context that never received a style.
			UTexture2D* StyleTexture = StyleTransferSettings->StyleTextures[i].LoadSynchronous();
			//UTexture2D* StyleTexture = LoadObject<UTexture2D>(this, TEXT("/Script/Engine.Texture2D'/StyleTransfer/T_StyleImage.T_StyleImage'"));
			if (!StyleTexture)
			{
				UE_LOG(LogStyleTransfer, Error, TEXT("Style texture %d could not be loaded, skipping it."), i);
				continue;
			}
#if WITH_EDITOR
			FTextureCompilingManager::Get().FinishCompilation({StyleTexture});
#endif

			const int32 StylePredictionInferenceContext = StylePredictionInferenceContexts.Emplace_GetRef(StylePredictionNetwork->CreateInferenceContext());
			checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Could not create inference context for StylePredictionNetwork"));

			UpdateStyle(StyleTexture, StylePredictionInferenceContext);
		}
		//UpdateStyle(FPaths::GetPath("C:\\projects\\realtime-style-transfer\\temp\\style_params_tensor.bin"));

		StyleTransferSceneViewExtension = FSceneViewExtensions::NewExtension<FStyleTransferSceneViewExtension>(ViewportClient->GetWorld(), ViewportClient, StyleTransferNetwork, StyleTransferInferenceContext.ToSharedRef());
	}

	StyleTransferSceneViewExtension->SetEnabled(true);
}
|
|
|
|
void UStyleTransferSubsystem::StopStylizingViewport()
|
|
{
|
|
FlushRenderingCommands();
|
|
StyleTransferSceneViewExtension.Reset();
|
|
if (StylePredictionInferenceContexts.Num())
|
|
{
|
|
for (auto It = StylePredictionInferenceContexts.CreateIterator(); It; ++It)
|
|
{
|
|
StylePredictionNetwork->DestroyInferenceContext(*It);
|
|
It.RemoveCurrent();
|
|
}
|
|
}
|
|
if (StyleTransferInferenceContext && *StyleTransferInferenceContext != INDEX_NONE)
|
|
{
|
|
StyleTransferNetwork->DestroyInferenceContext(*StyleTransferInferenceContext);
|
|
*StyleTransferInferenceContext = INDEX_NONE;
|
|
StyleTransferInferenceContext.Reset();
|
|
}
|
|
}
|
|
|
|
/**
 * Runs the style prediction network on a style texture and copies the predicted
 * style parameters into the style transfer network's style-params input.
 *
 * Rendering commands are flushed before enqueuing and again afterwards, so the work
 * has completed on the render thread when this function returns.
 *
 * @param StyleTexture                    Texture providing the style image; a null texture is rejected with an error.
 * @param StylePredictionInferenceContext Inference context of the style prediction network to run.
 */
void UStyleTransferSubsystem::UpdateStyle(UTexture2D* StyleTexture, int32 StylePredictionInferenceContext)
{
	checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not infer style without inference context"));
	checkf(StylePredictionInferenceContext != INDEX_NONE, TEXT("Can not update style without inference context"));

	if (!StyleTexture)
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("Can not update style without a style texture."));
		return;
	}

	FlushRenderingCommands();
	ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StyleTexture, StylePredictionInferenceContext](FRHICommandListImmediate& RHICommandList)
	{
		// Bail out before starting a capture or building a graph if the texture has
		// no usable render resource.
		FTextureResource* StyleTextureResource = StyleTexture->GetResource();
		if (!StyleTextureResource || !StyleTextureResource->TextureRHI.IsValid())
		{
			UE_LOG(LogStyleTransfer, Error, TEXT("Style texture has no render resource, can not update style."));
			return;
		}

		IRenderCaptureProvider* RenderCaptureProvider = ConditionalBeginRenderCapture(RHICommandList);

		FRDGBuilder GraphBuilder(RHICommandList);
		{
			RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");

			// Feed the style image into the prediction network's first input.
			const FNeuralTensor& InputStyleImageTensor = StylePredictionNetwork->GetInputTensorForContext(StylePredictionInferenceContext, 0);
			FRDGTextureRef RDGStyleTexture = GraphBuilder.RegisterExternalTexture(CreateRenderTarget(StyleTextureResource->TextureRHI, TEXT("StyleInputTexture")));
			FStyleTransferSceneViewExtension::TextureToTensor(GraphBuilder, RDGStyleTexture, InputStyleImageTensor);

			StylePredictionNetwork->Run(GraphBuilder, StylePredictionInferenceContext);

			const FNeuralTensor& OutputStyleParams = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContext, 0);
			const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);

			FRDGBufferRef OutputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(OutputStyleParams.GetPooledBuffer());
			FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
			check(OutputStyleParamsBuffer->GetSize() == InputStyleParamsBuffer->GetSize());
			check(OutputStyleParamsBuffer->GetSize() == OutputStyleParams.NumInBytes());
			check(InputStyleParamsBuffer->GetSize() == InputStyleParams.NumInBytes());

			// Destination first, source second: prediction output -> transfer network input.
			AddCopyBufferPass(GraphBuilder, InputStyleParamsBuffer, OutputStyleParamsBuffer);
		}
		GraphBuilder.Execute();

		if (RenderCaptureProvider)
		{
			RenderCaptureProvider->EndCapture(&RHICommandList);
		}
	});

	FlushRenderingCommands();
}
|
|
|
|
/**
 * Loads style parameters from a file and uploads them into the style transfer
 * network's style-params input.
 *
 * The file is deserialized with the archive's TArray<float> operator. Rendering
 * commands are flushed before returning.
 *
 * @param StyleTensorDataPath Path of the file containing the serialized style parameters.
 */
void UStyleTransferSubsystem::UpdateStyle(FString StyleTensorDataPath)
{
	checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not update style without inference context"));

	// CreateFileReader returns an archive the caller owns, or null if the file can not
	// be opened. Hold it in a TUniquePtr so it is closed and freed on every path.
	const TUniquePtr<FArchive> FileReader(IFileManager::Get().CreateFileReader(*StyleTensorDataPath));
	if (!FileReader)
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("Could not open style tensor data file '%s'."), *StyleTensorDataPath);
		return;
	}

	TArray<float> StyleParams;
	StyleParams.SetNumUninitialized(192);

	*FileReader << StyleParams;

	if (FileReader->IsError())
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("Could not read style tensor data from '%s'."), *StyleTensorDataPath);
		return;
	}

	ENQUEUE_RENDER_COMMAND(StyleParamsLoad)([this, StyleParams = MoveTemp(StyleParams)](FRHICommandListImmediate& RHICommandList)
	{
		const FNeuralTensor& InputStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);

		FRDGBuilder GraphBuilder(RHICommandList);
		{
			RDG_EVENT_SCOPE(GraphBuilder, "StyleParamsLoad");

			FRDGBufferRef InputStyleParamsBuffer = GraphBuilder.RegisterExternalBuffer(InputStyleParams.GetPooledBuffer());
			// NoCopy: the array is owned by this lambda and the graph executes below,
			// before the lambda returns.
			GraphBuilder.QueueBufferUpload(InputStyleParamsBuffer, StyleParams.GetData(), StyleParams.Num() * StyleParams.GetTypeSize(), ERDGInitialDataFlags::NoCopy);
		}
		GraphBuilder.Execute();
	});
	FlushRenderingCommands();
}
|
|
|
|
/**
 * Reacts to a change of r.StyleTransfer.Enabled: always stops the current
 * stylization, then restarts it on the game viewport if the variable is true.
 *
 * @param ConsoleVariable The variable that changed; must be CVarStyleTransferEnabled.
 */
void UStyleTransferSubsystem::HandleConsoleVariableChanged(IConsoleVariable* ConsoleVariable)
{
	check(ConsoleVariable == CVarStyleTransferEnabled.AsVariable());

	StopStylizingViewport();

	if (!CVarStyleTransferEnabled->GetBool())
	{
		return;
	}

	// Both networks are required, so (re)load whenever either one is missing.
	if (!StyleTransferNetwork || !StylePredictionNetwork)
	{
		LoadNetworks();
	}

	if (!StyleTransferNetwork || !StylePredictionNetwork)
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("Style transfer networks are not available, can not stylize viewport."));
		return;
	}

	UGameInstance* GameInstance = GetGameInstance();
	UGameViewportClient* GameViewportClient = GameInstance ? GameInstance->GetGameViewportClient() : nullptr;
	if (!GameViewportClient)
	{
		UE_LOG(LogStyleTransfer, Error, TEXT("No game viewport client available, can not stylize viewport."));
		return;
	}

	StartStylizingViewport(GameViewportClient);
}
|
|
|
|
void UStyleTransferSubsystem::LoadNetworks()
|
|
{
|
|
const UStyleTransferSettings* StyleTransferSettings = GetDefault<UStyleTransferSettings>();
|
|
StyleTransferNetwork = StyleTransferSettings->StyleTransferNetwork.LoadSynchronous();
|
|
StylePredictionNetwork = StyleTransferSettings->StylePredictionNetwork.LoadSynchronous();
|
|
|
|
if (StyleTransferNetwork->IsLoaded())
|
|
{
|
|
for (int32 i = 0; i < StyleTransferNetwork->GetInputTensorNumber(); ++i)
|
|
{
|
|
const FNeuralTensor& InputTensor = StyleTransferNetwork->GetInputTensor(i);
|
|
if (InputTensor.GetName() != "style_params")
|
|
continue;
|
|
|
|
StyleTransferStyleParamsInputIndex = i;
|
|
break;
|
|
}
|
|
StyleTransferNetwork->SetDeviceType(ENeuralDeviceType::GPU, ENeuralDeviceType::GPU, ENeuralDeviceType::GPU);
|
|
}
|
|
else
|
|
{
|
|
UE_LOG(LogStyleTransfer, Error, TEXT("StyleTransferNetwork could not be loaded"));
|
|
}
|
|
|
|
|
|
if (StylePredictionNetwork->IsLoaded())
|
|
{
|
|
StyleTransferNetwork->SetDeviceType(ENeuralDeviceType::GPU, ENeuralDeviceType::GPU, ENeuralDeviceType::GPU);
|
|
}
|
|
else
|
|
{
|
|
UE_LOG(LogStyleTransfer, Error, TEXT("StylePredictionNetwork could not be loaded."));
|
|
}
|
|
}
|
|
|
|
/**
 * Enqueues a render command that blends the predicted style parameters of two
 * style prediction contexts into the style transfer network's style-params input.
 *
 * @param StylePredictionInferenceContextA First style prediction context; must be owned by this subsystem.
 * @param StylePredictionInferenceContextB Second style prediction context; must be owned by this subsystem.
 * @param Alpha                            Blend factor passed to InterpolateTensors.
 */
void UStyleTransferSubsystem::InterpolateStyles(int32 StylePredictionInferenceContextA, int32 StylePredictionInferenceContextB, float Alpha)
{
	checkf(StyleTransferInferenceContext.IsValid() && (*StyleTransferInferenceContext) != INDEX_NONE, TEXT("Can not transfer style without inference context"));
	checkf(StylePredictionInferenceContexts.Contains(StylePredictionInferenceContextA), TEXT("Can not update style without inference context A"));
	checkf(StylePredictionInferenceContexts.Contains(StylePredictionInferenceContextB), TEXT("Can not update style without inference context B"));

	ENQUEUE_RENDER_COMMAND(StylePrediction)([this, StylePredictionInferenceContextA, StylePredictionInferenceContextB, Alpha](FRHICommandListImmediate& RHICommandList)
	{
		IRenderCaptureProvider* CaptureProvider = ConditionalBeginRenderCapture(RHICommandList);

		FRDGBuilder GraphBuilder(RHICommandList);
		{
			RDG_EVENT_SCOPE(GraphBuilder, "StylePrediction");

			// Sources: output 0 of the prediction network for each of the two contexts.
			const FNeuralTensor& StyleParamsA = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContextA, 0);
			const FNeuralTensor& StyleParamsB = StylePredictionNetwork->GetOutputTensorForContext(StylePredictionInferenceContextB, 0);
			// Destination: the style-params input of the transfer network.
			const FNeuralTensor& BlendedStyleParams = StyleTransferNetwork->GetInputTensorForContext(*StyleTransferInferenceContext, StyleTransferStyleParamsInputIndex);

			FStyleTransferSceneViewExtension::InterpolateTensors(GraphBuilder, BlendedStyleParams, StyleParamsA, StyleParamsB, Alpha);
		}
		GraphBuilder.Execute();

		if (CaptureProvider)
		{
			CaptureProvider->EndCapture(&RHICommandList);
		}
	});
}
|
|
|
|
/**
 * Begins a render capture if r.StyleTransfer.AutoCapturePrediction is set and a
 * render capture provider modular feature is available.
 *
 * @param RHICommandList Command list handed to the provider's BeginCapture.
 * @return The provider on which the capture was started (the caller ends the
 *         capture via EndCapture), or nullptr if no capture was started.
 */
IRenderCaptureProvider* UStyleTransferSubsystem::ConditionalBeginRenderCapture(FRHICommandListImmediate& RHICommandList)
{
	if (!CVarAutoCaptureStylePrediction.GetValueOnRenderThread())
	{
		return nullptr;
	}

	const FName CaptureFeatureName = IRenderCaptureProvider::GetModularFeatureName();
	if (!IModularFeatures::Get().IsModularFeatureAvailable(CaptureFeatureName))
	{
		return nullptr;
	}

	IRenderCaptureProvider& CaptureProvider = IModularFeatures::Get().GetModularFeature<IRenderCaptureProvider>(CaptureFeatureName);
	CaptureProvider.BeginCapture(&RHICommandList);
	return &CaptureProvider;
}
|