50 lines
1.9 KiB
Python
50 lines
1.9 KiB
Python
import os
from pathlib import Path
from typing import Optional, Tuple

import cv2
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from keras import backend as be
from matplotlib import pyplot as plt

# TF-Hub handle of Magenta's arbitrary-image-stylization model. The module-level
# `model` is loaded once at import time and reused for every image processed
# below (called as model(content_image, style_image)).
_MODEL_HANDLE = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
model = hub.load(_MODEL_HANDLE)


def load_image(img_path: str) -> Tuple[Optional[Path], Optional["tf.Tensor"]]:
    """Load an image file as a batched float32 tensor for the stylization model.

    Args:
        img_path: Path to an image file in any format
            ``tf.image.decode_image`` can decode.

    Returns:
        ``(Path(img_path), image)`` where ``image`` has shape
        ``(1, height, width, 3)`` and dtype float32 scaled to ``[0, 1]``; or
        ``(None, None)`` when the path is not a regular file or cannot be
        decoded as an image. A message is printed in both failure cases.
    """
    if not os.path.isfile(img_path):
        print(f"Couldn't find {img_path}!")
        return None, None

    raw = tf.io.read_file(img_path)
    try:
        # expand_animations=False decodes a GIF to its first frame as a 3-D
        # tensor instead of a 4-D stack of frames, so adding the batch axis
        # below always yields the rank-4 input the model is called with.
        img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    except tf.errors.InvalidArgumentError:
        # Non-image files (e.g. stray text/metadata files in inputs/) land here;
        # report and signal failure the same way as a missing file.
        print(f"Couldn't decode {img_path} as an image!")
        return None, None

    # Integer pixels -> float32; convert_image_dtype rescales into [0, 1].
    img = tf.image.convert_image_dtype(img, tf.float32)
    # Prepend a batch dimension of size 1.
    return Path(img_path), img[tf.newaxis, ...]


while True:
    # Keep asking until the named style image actually loads; load_image
    # returns (None, None) (and prints why) for missing or undecodable files.
    style_path, style_image = None, None
    while style_image is None:
        style_path, style_image = load_image(f'styles/{input("Style image: ")}')

    print("Processing...")
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs("outputs", exist_ok=True)
    # The style input is the same for every content image in this batch.
    style_tensor = tf.constant(style_image)

    # Sorted so images are processed in a deterministic order.
    for file in sorted(os.listdir("inputs")):
        content_path, content_image = load_image(f'inputs/{file}')
        if content_image is None:
            # Subdirectory or non-image file; load_image already reported it.
            continue

        stylized_image = model(tf.constant(content_image), style_tensor)[0]

        out = f"outputs/{style_path.stem}_{content_path.stem}.jpg"
        # Output is RGB float; presumably in [0, 1] like the inputs, so clip to
        # guard against overshoot, then round to 8-bit. OpenCV writes BGR, so
        # swap the channel order before saving.
        rgb_u8 = np.clip(np.rint(np.squeeze(stylized_image) * 255), 0, 255).astype(np.uint8)
        if cv2.imwrite(out, cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)):
            print(f"Done! Wrote file to {out}.")
        else:
            print(f"Failed to write {out}!")

    # Reset Keras global state between style batches (both backends are
    # cleared, as before). NOTE(review): `model` comes from hub.load rather
    # than being built with Keras, so this is presumably a no-op for it.
    be.clear_session()
    tf.keras.backend.clear_session()