import json
import random
import math

#---
# 1. Loading the training data
#---

def load_email_data(spam_path, ham_path):
    """Read spam and ham e-mail records from two JSON files.

    Each file must contain a JSON array of record dicts. Every record is
    tagged in place with a ``"label"`` key (``"spam"`` or ``"ham"``) and
    the two labelled lists are returned concatenated, spam first.
    """
    labelled = []
    for path, label in ((spam_path, "spam"), (ham_path, "ham")):
        with open(path, encoding='utf-8') as fh:
            records = json.load(fh)
        for record in records:
            record["label"] = label
        labelled.extend(records)
    return labelled


def train_test_split(data, test_ratio=0.2):
    """Randomly split *data* into (train, test) lists.

    Fix: the original called ``random.shuffle(data)`` and silently
    reordered the caller's list in place; we shuffle a copy so the
    argument is left untouched. The split point is
    ``int(len(data) * (1 - test_ratio))``, so the train part gets the
    floor of the (1 - test_ratio) share.
    """
    shuffled = list(data)
    random.shuffle(shuffled)
    cut = int(len(shuffled) * (1 - test_ratio))
    return shuffled[:cut], shuffled[cut:]


#---
# 2. Training the naive Bayes classifier
#---

def preprocess(text):
    """Tokenize *text*: lowercase, strip simple punctuation, split on whitespace.

    Hyphens, periods, exclamation marks and question marks are turned
    into spaces; commas are removed outright (joining their neighbours),
    exactly matching the original chain of ``str.replace`` calls.
    """
    table = str.maketrans({"-": " ", ".": " ", ",": "", "!": " ", "?": " "})
    return text.lower().translate(table).split()


def train_naive_bayes(train_data, alpha=1.0):
    """Collect the frequency statistics for a multinomial naive Bayes model.

    For every record in *train_data* (dicts with ``"text"`` and
    ``"label"`` keys) the document count per class, the per-class word
    counts and the per-class total word counts are accumulated.

    Returns a model dict with keys ``class_counts``, ``word_counts``,
    ``total_words``, ``vocab`` (set of all seen words), ``alpha``
    (smoothing constant, stored as given) and ``total_docs``.
    """
    class_counts = {}
    word_counts = {}
    total_words = {}

    for record in train_data:
        label = record["label"]
        # First sighting of a class: create its counters.
        if label not in class_counts:
            class_counts[label] = 0
            word_counts[label] = {}
            total_words[label] = 0
        class_counts[label] += 1

        per_class = word_counts[label]
        for token in preprocess(record["text"]):
            per_class[token] = per_class.get(token, 0) + 1
            total_words[label] += 1

    # Vocabulary is the union of every class's observed words.
    vocab = {token for counts in word_counts.values() for token in counts}

    return {
        "class_counts": class_counts,
        "word_counts": word_counts,
        "total_words": total_words,
        "vocab": vocab,
        "alpha": alpha,
        "total_docs": len(train_data),
    }


#---
# 4. Main function  (NOTE: section 3 is missing in the original numbering)
#---

from pprint import pprint

def main():
    """Load the labelled e-mail corpus, train a naive Bayes model, print it.

    Reads ``Spam.json`` and ``Ham.json`` from the working directory; the
    test portion of the split is currently unused.
    """
    dataset = load_email_data("Spam.json", "Ham.json")
    train_set, test_set = train_test_split(dataset)
    pprint(train_naive_bayes(train_set))


#---
# 5. Script entry point
#---

# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()