import json, random, math

# Load a JSON file into a list of records
def load_data(path):
    """Return a list of dicts {label: ..., tags: [...]} read from *path*.

    Bug fix: the encoding name previously contained a non-breaking hyphen
    (U+2011, "utf‑8"), which raises LookupError: unknown encoding.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# Split into training and test sets (80 / 20 by default)
def train_test_split(data, test_ratio=0.20):
    """Shuffle a copy of *data* and return (train, test) lists.

    Bug fix: the original shuffled the caller's list in place; the input
    sequence is now left untouched.
    """
    shuffled = list(data)
    random.shuffle(shuffled)
    cut = int(len(shuffled) * (1 - test_ratio))
    train, test = shuffled[:cut], shuffled[cut:]
    # Progress output kept from the original implementation.
    print(f"Ilosc probek zbioru uczącego: {len(train)}")
    print(f"Ilosc probek zbioru testowego: {len(test)}")
    return train, test

# Collect the set of all tags that occur in the training records
def build_vocabulary(train):
    """Return the set of every tag appearing in *train* records."""
    vocab = set().union(*(rec["tags"] for rec in train))
    print(vocab)
    return vocab

# Train a multinomial Naive Bayes classifier on tag lists
def train_nb(train, vocab, alpha=1.0):
    """Collect the count statistics needed for Naive Bayes prediction.

    Parameters:
        train: list of {"label": ..., "tags": [...]} records
        vocab: set of all known tags (used later for Laplace smoothing)
        alpha: Laplace smoothing constant

    Returns a model dict with:
        class_counts -- documents per class
        word_counts  -- per-class tag frequencies
        total_words  -- per-class total tag count
        vocab, alpha, total_docs -- carried through for log_prob()
    """
    class_counts = {}
    word_counts  = {}
    total_words  = {}

    for rec in train:
        c = rec["label"]
        class_counts[c]  = class_counts.get(c, 0) + 1
        word_counts.setdefault(c, {})
        total_words.setdefault(c, 0)

        for tag in rec["tags"]:
            word_counts[c][tag] = word_counts[c].get(tag, 0) + 1
            total_words[c]     += 1

    # Fix: the original printed the entire model here (debug leftover);
    # removed to keep stdout clean.
    return {
        "class_counts": class_counts,
        "word_counts":  word_counts,
        "total_words":  total_words,
        "vocab":        vocab,
        "alpha":        alpha,
        "total_docs":   len(train),
    }

# Log-probability of one class for a record (Laplace-smoothed)
def log_prob(model, rec, class_name):
    """Return log P(class) + sum of log P(tag | class) over rec's tags."""
    alpha      = model["alpha"]
    vocab_size = len(model["vocab"])
    tag_counts = model["word_counts"][class_name]
    # Smoothing denominator is constant per class — hoisted out of the loop.
    denom = model["total_words"][class_name] + alpha * vocab_size

    prior = math.log(model["class_counts"][class_name] / model["total_docs"])
    likelihood = sum(
        math.log((tag_counts.get(tag, 0) + alpha) / denom)
        for tag in rec["tags"]
    )
    return prior + likelihood

# Prediction: choose the class with the highest log-probability
def predict(model, rec):
    """Return the label maximizing log_prob(model, rec, label)."""
    classes = model["class_counts"]
    if not classes:
        # Mirrors the original behavior: no classes -> None.
        return None
    # max() keeps the first maximal element, matching the original
    # strict-greater-than comparison on ties.
    return max(classes, key=lambda label: log_prob(model, rec, label))

# Evaluation on the held-out test set
def evaluate(model, test):
    """Print and return classification accuracy on *test*.

    Fixes: returns the accuracy (previously only printed it) and no longer
    raises ZeroDivisionError when the test set is empty.
    """
    if not test:
        # Nothing to score; the original divided by len(test) == 0 here.
        return 0.0
    correct = sum(1 for rec in test if predict(model, rec) == rec["label"])
    accuracy = correct / len(test)
    print(f"Dokładność (accuracy) = {accuracy:.2%}")
    return accuracy

# Main entry point
def main():
    """Load the data set, train Naive Bayes and report test accuracy."""
    path = "tagi_gier_wspoldzielone.json"
    data = load_data(path)
    train, test = train_test_split(data, test_ratio=0.20)
    # Fix: build the vocabulary from the TRAINING split only; the original
    # passed the full data set, leaking test-set tags into the model.
    vocab = build_vocabulary(train)
    model = train_nb(train, vocab, alpha=1.0)

    sample = test[0]
    print("\nPrzykładowe tagi:", sample["tags"])
    print("Rzeczywista gra:", sample["label"])
    print("Model przewidział:", predict(model, sample))

    evaluate(model, test)


# Fix: guard the entry call so importing this module has no side effects.
if __name__ == "__main__":
    main()