Rewriting a Deep Generative Model, ECCV 2020 (oral). Interactive tool to directly edit the rules of a GAN to synthesize scenes with objects added, removed, or altered. Change StyleGANv2 to make extravagant eyebrows, or horses wearing hats.
MIT License
In this paper, we ask if a deep network can be reprogrammed to follow different rules, by enabling a user to directly change the weights, instead of training with a data set.
We present the task of model rewriting, which aims to add, remove, and alter the semantic and physical rules of a pre-trained deep network. While modern image editing tools achieve a user-specified goal by manipulating individual input images, we enable a user to synthesize an unbounded number of new images by editing a generative model to carry out modified rules.
There are two reasons to want to rewrite a deep network directly:
Model rewriting envisions a way to construct deep networks according to a user's intentions. Rather than limiting networks to imitating data that we already have, rewriting allows deep networks to model a world that follows new rules that a user wishes to have.
Rewriting a Deep Generative Model. David Bau, Steven Liu, Tongzhou Wang, Jun-Yan Zhu, Antonio Torralba. ECCV 2020 (oral). MIT CSAIL and Adobe Research.
Our method rewrites the weights of a generator to change generative rules. Instead of editing individual images, our method edits the generator, so a potentially infinite set of images can be synthesized and manipulated using the altered rules. Rules can be changed in various ways, such as removing patterns like watermarks, adding objects such as people, or replacing definitions such as making trees grow out of towers.
Our method is based on the hypothesis that the weights of a generator act as a linear associative memory. A layer stores a map between keys, which denote meaningful context, and values, which determine output.
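To make the associative-memory view concrete, here is a minimal NumPy sketch on random data. It is not the repository's code, and `fit_memory` and `rank_one_insert` are hypothetical names. Keys are the columns of K and values are the columns of V, and the least-squares memory is W = V K^T (K K^T)^-1. Storing one new rule k* -> v* with the smallest possible increase in error on the old keys is a constrained least-squares problem. Its closed-form answer is a rank-one update along the direction C^-1 k*, where C = K K^T.

```python
import numpy as np


def fit_memory(K, V):
    """Least-squares linear associative memory: W minimizes ||V - W K||^2.

    K holds one key per column (d_k x n); V holds one value per column (d_v x n).
    """
    C = K @ K.T  # key second-moment matrix (d_k x d_k)
    return V @ K.T @ np.linalg.inv(C)


def rank_one_insert(W0, C, k_star, v_star):
    """Return W1 with W1 @ k_star == v_star that minimizes ||V - W1 K||^2,
    given that W0 is the unconstrained least-squares solution."""
    d = np.linalg.solve(C, k_star)      # update direction C^-1 k*
    residual = v_star - W0 @ k_star     # what the old memory gets wrong for k*
    lam = residual / (d @ k_star)       # scale so the constraint holds exactly
    return W0 + np.outer(lam, d)        # rank-one change to the weights


rng = np.random.default_rng(0)
d_k, d_v, n = 8, 5, 200
K = rng.standard_normal((d_k, n))
W_true = rng.standard_normal((d_v, d_k))
V = W_true @ K                          # noiseless key -> value pairs

W0 = fit_memory(K, V)
C = K @ K.T
k_star = rng.standard_normal(d_k)
v_star = rng.standard_normal(d_v)
W1 = rank_one_insert(W0, C, k_star, v_star)

print(np.allclose(W0, W_true))           # True: the stored map is recovered
print(np.allclose(W1 @ k_star, v_star))  # True: the new rule is stored
```

Because the update only moves W along C^-1 k*, any old key k with k^T C^-1 k* close to zero keeps nearly the same output. This is the sense in which such a change is a minimal edit to the memory.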
The results below show changes of a single rule within StyleGANv2. In each case, four examples chosen by the user (center of the top row) establish the context for the rule being rewritten, and the "copy and paste" examples (left and right of top row) indicate how the user wishes to change the model.
The grid below shows pairs of outputs: for each pair, the first is the output of the original unmodified StyleGANv2. The second is the output of the modified StyleGANv2, applying the user's intention using our method.
First: changing the rule defining kids' eyebrows to make them look like a bushy mustache.
Altering the rule for pointy tower tops to make them into trees.
Changing the rule for the tops of horses' heads to put hats on horses.
Changing frowns into smiles.
Removing the main window in a building by changing the rule to draw a blank wall.
The code runs using PyTorch.
The rewriting code is in `/rewrite`, and the notebooks are in `/notebooks`. See `rewriting-interface.ipynb` for the demonstration UI. Metrics code is in `/metrics`, with dissection utilities in `/utils`. Experiment scripts are provided as `./experiments.sh` and `./watermarks.sh`, and the reflection rule-change experiment is in `notebooks/reflection-rule-change.ipynb`.

It's designed to use a recent version of PyTorch (1.4+) on Python 3.6, with CUDA 10.1 and cuDNN 7.6.0. Run `setup/setup_renv.sh` to create a conda environment that has the needed dependencies.
To edit your own models, do the following:
Convert your model's weights with the `convert_weight.py` utility, then load the model:
import torch
# SeqStyleGAN2 comes from this repository; import it from its module (path not shown here).
# Resolution (size) and style dimensionality (style_dim and n_mlp) are
# the architecture dimensions as you trained them. The truncation trick can be
# applied here if desired (truncation=1.0 if not).
# Note that mconv='seq' splits apart the optimized modulated convolution into
# separate operations so that the rewriter can examine the underlying
# convolution directly.
model = SeqStyleGAN2(size=256, style_dim=512, n_mlp=8, truncation=0.5, mconv='seq')
# load the exponential moving average model weights, put it on the GPU.
state_dict = torch.load('your_model.pt')
model.load_state_dict(state_dict['g_ema'], latent_avg=state_dict['latent_avg'])
model.cuda()
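The `mconv='seq'` comment rests on a simple identity. StyleGANv2's modulated convolution first scales the weight's input channels by the style. It then demodulates, dividing each output channel's weights by their L2 norm. Both scalings can be moved outside a plain convolution with the unmodified weight. The NumPy sketch below checks this numerically. For brevity it writes a 1x1 convolution as a matrix product, an assumption made for illustration. The same algebra holds for larger kernels because convolution is linear. This is not the repository's `SeqStyleGAN2` code.

```python
import numpy as np

rng = np.random.default_rng(0)
c_in, c_out, pixels = 4, 3, 10
x = rng.standard_normal((c_in, pixels))      # input features, one column per pixel
weight = rng.standard_normal((c_out, c_in))  # 1x1 conv weight shared by all images
style = rng.uniform(0.5, 2.0, size=c_in)     # per-input-channel style scale
eps = 1e-8

# Fused form: modulate the weight, demodulate each output row, then convolve.
w_mod = weight * style[None, :]
demod = 1.0 / np.sqrt((w_mod ** 2).sum(axis=1) + eps)
y_fused = (w_mod * demod[:, None]) @ x

# Sequential form: scale the input, apply the unmodified weight, rescale outputs.
y_seq = demod[:, None] * (weight @ (style[:, None] * x))

print(np.allclose(y_fused, y_seq))  # True
```

In the sequential form the unmodified `weight` acts directly on style-scaled features. This is why splitting the operation apart exposes a plain convolution for the rewriter to examine.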
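Create a `ganrewrite.SeqStyleGanRewriter` instance to edit your model:
# Import paths below assume the repository root is on sys.path.
from rewrite.ganrewrite import SeqStyleGanRewriter
from utils import zdataset
layernum = 8  # or whichever layer you wish to modify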
sample_size = 1000 # a small sample of images for computing statistics
zds = zdataset.z_dataset_for_model(model, size=sample_size)
gw = SeqStyleGanRewriter(
model, zds, layernum,
cachedir='experiments')
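The `sample_size` images are used for computing statistics. Under the associative-memory view, a natural statistic is the second-moment matrix of the layer's keys, C = K K^T. It sums over every spatial location of every sampled image. The NumPy sketch below accumulates C batch by batch, so the full key matrix never has to be held in memory. It makes three assumptions: features arrive as (batch, channels, height, width) arrays, random stand-ins replace real generator activations, and `accumulate_key_moment` is a hypothetical helper rather than part of `SeqStyleGanRewriter`.

```python
import numpy as np


def accumulate_key_moment(batches):
    """Accumulate C = sum of k k^T over all keys in a stream of feature batches.

    Each batch has shape (batch, channels, height, width); every spatial
    location of every image is treated as one key vector of length `channels`.
    Returns (C, count) so that a mean can be formed as C / count.
    """
    C, count = None, 0
    for feats in batches:
        # Move channels first, then flatten images and pixels into key columns.
        keys = feats.transpose(1, 0, 2, 3).reshape(feats.shape[1], -1)
        C = keys @ keys.T if C is None else C + keys @ keys.T
        count += keys.shape[1]
    return C, count


rng = np.random.default_rng(0)
batches = [rng.standard_normal((10, 6, 4, 4)) for _ in range(5)]  # stand-in features
C, count = accumulate_key_moment(batches)

# Same result as stacking every key column at once.
all_keys = np.concatenate(
    [b.transpose(1, 0, 2, 3).reshape(6, -1) for b in batches], axis=1)
print(np.allclose(C, all_keys @ all_keys.T), count)  # True 800
```

Streaming the sum keeps memory proportional to channels squared instead of to the number of keys. This matters because a sample of 1000 images at a mid-resolution layer yields many key vectors.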
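Open the `rewriteapp.GanRewriteApp` interface (assumes you are running …):
from rewrite import rewriteapp  # import path assumes the repository root is on sys.path
savedir = 'masks'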
interface = rewriteapp.GanRewriteApp(gw, size=256, mask_dir=savedir, num_canvases=32)
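To apply a saved edit, see `metrics/sample_edited.py` for an example:
import json
saved_edit = 'masks/my_edit.json'
# json.load expects an open file object, not a path string.
with open(saved_edit) as f:
    gw.apply_edit(json.load(f), rank=1)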