v0.0.2 still image
vinthony committed Apr 8, 2023
1 parent 8603ff9 commit 479a5ad
Showing 21 changed files with 200 additions and 112 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -162,4 +162,6 @@ cython_debug/
examples/results/*
gfpgan/*
checkpoints/
results/*
results/*
Dockerfile
start_docker.sh
45 changes: 0 additions & 45 deletions Dockerfile

This file was deleted.

1 change: 1 addition & 0 deletions checkpoints
Empty file.
Binary file added docs/sadtalker_logo.png
Binary file removed docs/still1.mp4
Binary file removed docs/still_e.mp4
Binary file removed docs/still_full_e.mp4
Binary file added examples/source_image/full3.png
Binary file added examples/source_image/full4.jpeg
28 changes: 16 additions & 12 deletions inference.py
@@ -8,7 +8,6 @@
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from src.utils.paste_pic import paste_pic

def main(args):
#torch.backends.cudnn.enabled = False
@@ -43,8 +42,13 @@ def main(args):
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')
mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

if args.preprocess == 'full':
    mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar')
    facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
else:
    mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
    facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

#init model
print(path_of_net_recon_model)
@@ -92,7 +96,7 @@ def main(args):
ref_pose_coeff_path=None

# audio2coeff
batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path)
batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

# 3dface render
@@ -103,16 +107,16 @@
#coeff2video
data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list, input_pitch_list, input_roll_list,
expression_scale=args.expression_scale, still_mode=args.still)
expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)

animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
enhancer=args.enhancer, full_img_enhancer=args.full_img_enhancer)
enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)

if __name__ == '__main__':

parser = ArgumentParser()
parser.add_argument("--driven_audio", default='./examples/driven_audio/bus_chinese.wav', help="path to driven audio")
parser.add_argument("--source_image", default='./examples/source_image/full_body_2.png', help="path to source image")
parser.add_argument("--driven_audio", default='./examples/driven_audio/eluosi.wav', help="path to driven audio")
parser.add_argument("--source_image", default='./examples/source_image/full3.png', help="path to source image")
parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
@@ -123,12 +127,12 @@ def main(args):
parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan]")
parser.add_argument('--full_img_enhancer', type=str, default=None, help="Full image enhancer, [gfpgan]")
parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
parser.add_argument("--cpu", dest="cpu", action="store_true")
parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
parser.add_argument("--still", action="store_true")
parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize'] )
parser.add_argument("--still", action="store_true", help="can crop back to the orginal videos for the full body aniamtion")
parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images" )

# net structure and parameters
parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
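Taken together, the inference.py changes route full-body runs through a dedicated mapping checkpoint and facerender config, selected purely by --preprocess. A minimal sketch of that selection, with a hypothetical helper name (only the hard-coded paths come from the diff):

import os

def select_facerender_assets(preprocess, checkpoint_dir='./checkpoints'):
    # 'full' preprocessing swaps in the still-mode mapping net and its
    # matching config; every other mode keeps the original pair.
    if preprocess == 'full':
        return (os.path.join(checkpoint_dir, 'mapping_00109-model.pth.tar'),
                os.path.join('src', 'config', 'facerender_still.yaml'))
    return (os.path.join(checkpoint_dir, 'mapping_00229-model.pth.tar'),
            os.path.join('src', 'config', 'facerender.yaml'))

mapping_checkpoint, facerender_yaml_path = select_facerender_assets('full')

A typical full-body run with the new flags would then look like: python inference.py --source_image examples/source_image/full3.png --driven_audio examples/driven_audio/eluosi.wav --preprocess full --still --enhancer gfpgan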
45 changes: 45 additions & 0 deletions src/config/facerender_still.yaml
@@ -0,0 +1,45 @@
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 73
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66
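The notable difference from the cropped-mode facerender.yaml appears to be coeff_nc: 73 under mapping_params: in full mode the mapping net consumes three extra 3DMM columns (see the [:1,:73] slice in generate_facerender_batch.py below). A minimal check, assuming PyYAML:

import yaml

# Read the still-mode facerender config and inspect the widened input
# size of the mapping network (70 coefficients plus 3 extra columns).
with open('src/config/facerender_still.yaml') as f:
    config = yaml.safe_load(f)

print(config['model_params']['mapping_params']['coeff_nc'])  # 73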

48 changes: 20 additions & 28 deletions src/facerender/animate.py
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import warnings
from skimage import img_as_ubyte

warnings.filterwarnings('ignore')

import imageio
@@ -17,7 +18,7 @@
from pydub import AudioSegment
from src.utils.face_enhancer import enhancer as face_enhancer
from src.utils.paste_pic import paste_pic

from src.utils.videoio import save_video_with_watermark


class AnimateFromCoeff():
@@ -116,7 +117,7 @@ def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,

return checkpoint['epoch']

def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, full_img_enhancer=None):
def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):

source_image=x['source_image'].type(torch.FloatTensor)
source_semantics=x['source_semantics'].type(torch.FloatTensor)
@@ -165,17 +166,6 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, full_i
path = os.path.join(video_save_dir, 'temp_'+video_name)
imageio.mimsave(path, result, fps=float(25))

if enhancer:
    video_name_enhancer = x['video_name'] + '_enhanced.mp4'
    av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
    enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
    enhanced_images = face_enhancer(result, method=enhancer)

    if original_size:
        enhanced_images = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in enhanced_images ]

    imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))

av_path = os.path.join(video_save_dir, video_name)
return_path = av_path

@@ -190,27 +180,29 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, full_i
word = word1[start_time:end_time]
word.export(new_audio_path, format="wav")

cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (path, new_audio_path, av_path)
os.system(cmd)
save_video_with_watermark(path, new_audio_path, av_path, watermark=None)
print(f'The generated video is named {video_name} in {video_save_dir}')

if enhancer:
return_path = av_path_enhancer
cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (enhanced_path, new_audio_path, av_path_enhancer)
os.system(cmd)
os.remove(enhanced_path)
print(f'The generated video is named {video_name_enhancer} in {video_save_dir}')

if len(crop_info) == 3:
if preprocess.lower() == 'full':
    # only add watermark to the full image.
    video_name_full = x['video_name'] + '_full.mp4'
    full_video_path = os.path.join(video_save_dir, video_name_full)
    return_path = full_video_path
    if enhancer:
        paste_pic(av_path_enhancer, pic_path, crop_info, new_audio_path, full_video_path)
    else:
        paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
    print(f'The generated video is named {video_name_full} in {video_save_dir}')
    paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
    print(f'The generated video is named {video_save_dir}/{video_name_full}')

#### paste back then enhancers
if enhancer:
    video_name_enhancer = x['video_name'] + '_enhanced.mp4'
    enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
    av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer)
    return_path = av_path_enhancer
    enhanced_images = face_enhancer(full_video_path, method=enhancer, bg_upsampler=background_enhancer)
    imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))

    save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark=None)
    print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
    os.remove(enhanced_path)

os.remove(path)
os.remove(new_audio_path)
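Both raw ffmpeg call sites are now funneled through src/utils/videoio.save_video_with_watermark, and enhancement runs after paste_pic, so the enhancer sees the pasted full frame rather than the 256-pixel crop (note that, as written, the enhancer branch reads full_video_path, which is only assigned on the 'full' path). The videoio module itself is not part of this diff; a minimal sketch of what such a helper plausibly does, assuming it simply wraps the audio mux the removed os.system line performed:

import os
import uuid

def save_video_with_watermark(video_path, audio_path, save_path, watermark=None):
    # Hypothetical sketch: mux the driven audio onto the silent video.
    temp_file = str(uuid.uuid4()) + '.mp4'
    cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (
        video_path, audio_path, temp_file)
    os.system(cmd)
    # A real implementation would composite the watermark image here.
    os.replace(temp_file, save_path)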
4 changes: 4 additions & 0 deletions src/facerender/modules/generator.py
@@ -225,6 +225,8 @@ def forward(self, source_image, kp_driving, kp_source):
kp_source=kp_source)
output_dict['mask'] = dense_motion['mask']

# import pdb; pdb.set_trace()

if 'occlusion_map' in dense_motion:
occlusion_map = dense_motion['occlusion_map']
output_dict['occlusion_map'] = occlusion_map
@@ -238,6 +240,8 @@ def forward(self, source_image, kp_driving, kp_source):
out = self.third(out)
out = self.fourth(out)

# occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)

if occlusion_map is not None:
if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
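For context on the debug leftovers above: the dense-motion network emits a soft occlusion map in [0, 1] that gates the decoded features, and the commented-out torch.where line would have hard-thresholded it at 0.95. A self-contained sketch of the resize-and-gate step around the interpolation shown in the diff (shapes are illustrative):

import torch
import torch.nn.functional as F

out = torch.randn(1, 64, 128, 128)        # decoded feature map (N, C, H, W)
occlusion_map = torch.rand(1, 1, 64, 64)  # soft mask from dense motion

# Bring the mask up to the feature resolution, then gate the features.
if out.shape[2:] != occlusion_map.shape[2:]:
    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:],
                                  mode='bilinear', align_corners=False)
out = out * occlusion_map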
9 changes: 7 additions & 2 deletions src/generate_batch.py
@@ -48,7 +48,7 @@ def generate_blink_seq_randomly(num_frames):
break
return ratio

def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path):
def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False):

syncnet_mel_step_size = 16
fps = 25
@@ -95,7 +95,12 @@ def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path):
ref_coeff[:, :64] = refeyeblink_coeff[:num_frames, :64]

indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
ratio = torch.FloatTensor(ratio).unsqueeze(0) # bs T

if still:
    ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)  # bs T
else:
    ratio = torch.FloatTensor(ratio).unsqueeze(0)  # bs T
ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0) # bs 1 70

indiv_mels = indiv_mels.to(device)
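The new still flag zeroes the randomly generated blink sequence in place, so the audio-to-expression network receives a constant eye signal and the portrait keeps its eyes steady instead of blinking. In isolation:

import torch

ratio = [0.0, 0.2, 1.0, 0.2, 0.0]  # one synthetic blink over five frames

# still=True: fill_(0.) overwrites every blink ratio in place -> no blinking.
ratio_still = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)
print(ratio_still)        # tensor([[0., 0., 0., 0., 0.]])

# still=False keeps the generated sequence, shaped (bs, T).
ratio_blink = torch.FloatTensor(ratio).unsqueeze(0)
print(ratio_blink.shape)  # torch.Size([1, 5])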
12 changes: 9 additions & 3 deletions src/generate_facerender_batch.py
@@ -7,7 +7,7 @@

def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None,
expression_scale=1.0, still_mode = False):
expression_scale=1.0, still_mode = False, preprocess='crop'):

semantic_radius = 13
video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
@@ -25,7 +25,12 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
data['source_image'] = source_image_ts

source_semantics_dict = scio.loadmat(first_coeff_path)
source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70] #1 70

if preprocess.lower() != 'full':
    source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70]  # 1 70
else:
    source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73]  # 1 73

source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1)
@@ -37,6 +42,7 @@ def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path,
generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale

if still_mode:
    generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
    generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], generated_3dmm.shape[0], axis=0)

with open(txt_path+'.txt', 'w') as f:
Expand Down Expand Up @@ -82,7 +88,7 @@ def transform_semantic_1(semantic, semantic_radius):

def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):
num_frames = coeff_3dmm.shape[0]
seq = list(range(frame_index- semantic_radius, frame_index+ semantic_radius+1))
seq = list(range(frame_index- semantic_radius, frame_index + semantic_radius+1))
index = [ min(max(item, 0), num_frames-1) for item in seq ]
coeff_3dmm_g = coeff_3dmm[index, :]
return coeff_3dmm_g.transpose(1,0)
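Two things happen under still_mode: the predicted sequence is widened from 70 to 73 columns by tiling the source's extra coefficients across all frames (with a 70-column source the [:, 70:] slice is empty, so the concatenation is a no-op in crop mode), and columns 64 onward are frozen to the source frame so the head pose stays fixed. A small NumPy illustration with dummy values:

import numpy as np

T = 4
generated_3dmm = np.zeros((T, 70))                   # per-frame predictions
source_semantics = np.arange(73, dtype=float)[None]  # (1, 73) in 'full' mode

# Tile the source's extra columns (70:73) across all T frames and append.
generated_3dmm = np.concatenate(
    [generated_3dmm, np.repeat(source_semantics[:, 70:], T, axis=0)], axis=1)

# Freeze everything from column 64 on to the source values -> still head.
generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], T, axis=0)
print(generated_3dmm.shape)  # (4, 73)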
7 changes: 4 additions & 3 deletions src/utils/croper.py
@@ -161,7 +161,7 @@ def align_face(self, img, lm, output_size=1024):
# img_np_list[_i] = _inp
# return img_np_list

def crop(self, img_np_list, xsize=512): # first frame for all video
def crop(self, img_np_list, still=False, xsize=512): # first frame for all video
img_np = img_np_list[0]
lm = self.get_landmark(img_np)
if lm is None:
@@ -174,7 +174,8 @@ def crop(self, img_np_list, xsize=512): # first frame for all video
_inp = img_np_list[_i]
_inp = _inp[cly:cry, clx:crx]
# cv2.imwrite('test1.jpg', _inp)
_inp = _inp[ly:ry, lx:rx]
if not still:
    _inp = _inp[ly:ry, lx:rx]
# cv2.imwrite('test2.jpg', _inp)
img_np_list[_i] = _inp
return img_np_list, crop, quad
Expand Down Expand Up @@ -292,4 +293,4 @@ def get_wra_data_path(video_dir):
device_ids = opt.device_ids.split(",")
device_ids = cycle(device_ids)
for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
None
None
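With still=True the croper keeps only the coarse detector box (cly:cry, clx:crx) and skips the tighter landmark-quad crop, presumably so the frames passed downstream retain enough surrounding context for paste_pic to blend them back into the full image. The two-stage crop on a dummy frame, with illustrative coordinates:

import numpy as np

frame = np.zeros((512, 512, 3), dtype=np.uint8)
clx, cly, crx, cry = 100, 80, 400, 420  # coarse face-detector box
lx, ly, rx, ry = 40, 30, 260, 300       # tighter landmark quad (box-relative)

still = True
_inp = frame[cly:cry, clx:crx]  # coarse crop: (340, 300, 3)
if not still:
    _inp = _inp[ly:ry, lx:rx]   # only the non-still path tightens further
print(_inp.shape)               # (340, 300, 3) when still=True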