#1 - Library imports
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import Model
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

#2 - Image loading helper
def load_image(path, max_dim=512):
    """Load an RGB image from *path*, scaled so its longest side is *max_dim*.

    Returns a float32 numpy array of shape (height, width, 3) with raw
    pixel values in [0, 255].
    """
    picture = Image.open(path).convert('RGB')
    ratio = max_dim / max(picture.size)
    width, height = picture.size
    picture = picture.resize((int(width * ratio), int(height * ratio)),
                             Image.Resampling.LANCZOS)
    return np.asarray(picture, dtype=np.float32)

#3 - Image preprocessing
def preprocess(image):
    """Convert *image* to a float32 tensor and apply Keras' VGG16
    input preprocessing (channel reordering and mean subtraction)."""
    tensor = tf.convert_to_tensor(image, dtype=tf.float32)
    return preprocess_input(tensor)

def gram_matrix(tensor):
    """Return the Gram matrix of a rank-4 (batch, rows, cols, channels)
    feature map, normalized by the number of spatial locations."""
    gram = tf.linalg.einsum('bijc,bijd->bcd', tensor, tensor)
    shape = tf.shape(tensor)
    locations = tf.cast(shape[1] * shape[2], tf.float32)
    return gram / locations

#4 - Load the content and style images
# Paths are hard-coded; both files are expected next to this script.
content_path = 'content.jpg'
style_path = 'style.jpg'

content_image = load_image(content_path)
style_image = load_image(style_path)

#5 - Selection of style and content layers
# VGG16 layer names used for feature extraction.
# Fix: the first entry was 'block_conv1', which is not a valid VGG16 layer
# name ('block1_conv1' is) -- vgg.get_layer() below would have raised a
# ValueError before the model could be built.
layer_names = [
    'block1_conv1', 'block2_conv1',
    'block3_conv1', 'block4_conv1',
    'block5_conv1'
]

# Style is matched on the four shallower layers; content on the deepest one.
style_layers = layer_names[:-1]
content_layer = layer_names[-1]

#6 - Model construction and feature extraction
vgg = VGG16(include_top=False, weights='imagenet')
vgg.trainable = False
extractor_model = Model(
    inputs=vgg.input,
    outputs=[vgg.get_layer(name).output for name in layer_names],
)

def extract_features(image):
    """Run *image* through the extractor and return {layer_name: features}.

    A 3-D (unbatched) image is given a leading batch axis before
    preprocessing.
    """
    batched = tf.expand_dims(image, axis=0) if image.ndim == 3 else image
    features = extractor_model(preprocess(batched))
    return dict(zip(layer_names, features))

style_targets = extract_features(style_image)
content_targets = extract_features(content_image)
style_grams = {name: gram_matrix(style_targets[name]) for name in style_layers}

# NOTE: an exact byte-for-byte duplicate of section #6 above (VGG16
# construction, extract_features, style/content targets and the style Gram
# matrices) appeared here, most likely a copy/paste accident.  Running it
# twice only rebuilt the identical model and recomputed identical targets,
# so the duplicate has been removed; all names it defined are still bound
# by the original section above.


#7 - Inicjalizacja modelu
image = tf.Variable(content_image / 255.0)
style_weight = 1e4
content_weight = 10
opt = tf.optimizers.Adam(learning_rate=0.02)

#8 - Training step
def train_step(image):
    """Perform one style-transfer optimization step on *image*.

    *image* is the tf.Variable holding the generated picture with pixel
    values in [0, 1].  The step computes the content loss against the
    deepest VGG layer, the style loss as the squared Gram-matrix mismatch
    over the style layers, applies one Adam update, clips the image back
    into [0, 1], and returns the total loss.

    Fix: the original returned None, so callers had no way to monitor
    convergence; the scalar total loss is now returned (backward
    compatible -- existing callers that ignore the result are unaffected).
    """
    with tf.GradientTape() as tape:
        # The VGG extractor expects raw pixel values in [0, 255].
        outputs = extract_features(image * 255.0)
        content_output = outputs[content_layer]
        style_output = {name: outputs[name] for name in style_layers}

        content_loss = tf.reduce_mean((content_output - content_targets[content_layer]) ** 2)
        style_loss = tf.add_n([
            tf.reduce_mean((gram_matrix(style_output[name]) - style_grams[name]) ** 2)
            for name in style_layers
        ])

        total_loss = style_weight * style_loss + content_weight * content_loss

    grad = tape.gradient(total_loss, image)
    opt.apply_gradients([(grad, image)])
    # Keep the image in a valid displayable range after the update.
    image.assign(tf.clip_by_value(image, 0.0, 1.0))
    return total_loss