# ===
# Import bibliotek
# ===
import os
import fastf1
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt

# ===
# Tworzenie cache
# ===
if not os.path.exists('cache'):
    os.makedirs('cache')

fastf1.Cache.enable_cache('cache')

# ===
# Pobieranie danych kwalifikacji
# ===
qualy = fastf1.get_session(2026, 'Suzuka', 'R')
qualy.load()
laps_df = qualy.laps

# ===
# Tworzenie datasetu
# ===

# Najlepszy czas okrążenia dla każdego kierowcy
df = laps_df.groupby('Driver')['LapTime'].min().reset_index()

# Usunięcie ewentualnych braków danych
df = df.dropna()

# Konwersja czasu na sekundy (KLUCZOWE dla ML)
df['LapTimeSeconds'] = df['LapTime'].dt.total_seconds()

# Dodanie kolumny Top3
top3_max = df['LapTimeSeconds'].nsmallest(3).max()
df['Top3'] = (df['LapTimeSeconds'] <= top3_max).astype(int)

# ===
# Dodanie różnicy od najszybszego
# ===
best_time = df['LapTimeSeconds'].min()
df['DiffToBest'] = df['LapTimeSeconds'] - best_time

print("\nPodgląd danych z różnicą do najlepszego kierowcy:")
print(df)

# ===
# Analiza statystyczna
# ===
mean_time = df['LapTimeSeconds'].mean()
median_time = df['LapTimeSeconds'].median()
max_time = df['LapTimeSeconds'].max()
min_time = df['LapTimeSeconds'].min()

print("\nProsta analiza statystyczna:")
print(f"Średni czas: {mean_time:.3f} s")
print(f"Mediana czasu: {median_time:.3f} s")
print(f"Najszybszy czas: {min_time:.3f} s")
print(f"Najwolniejszy czas: {max_time:.3f} s")

# ===
# Trening modelu ML
# ===
X = df[['LapTimeSeconds', 'DiffToBest']]
Y = df['Top3']

model = DecisionTreeClassifier(random_state=42)
model.fit(X, Y)

print("\nModel został wytrenowany")

# ===
# Testowanie modelu
# ===
test_times = pd.DataFrame({
    'LapTimeSeconds': [min_time, median_time, max_time],
    'DiffToBest': [0, median_time - min_time, max_time - min_time]
})

pred = model.predict(test_times)

print("\nPredykcje TOP3 dla przykładowych czasów:")
for t, p in zip(test_times['LapTimeSeconds'], pred):
    print(f"Czas: {t:.3f}s -> Predykcja TOP3: {p}")

# ===
# Wizualizacja Top3 + różnicy
# ===
df_sorted = df.sort_values('LapTimeSeconds')
colors = df_sorted['Top3'].map({1: "red", 0: "blue"})

plt.figure(figsize=(12, 6))
bars = plt.bar(df_sorted['Driver'], df_sorted['LapTimeSeconds'], color=colors)

# Dodanie różnicy nad słupkami
for bar, diff in zip(bars, df_sorted['DiffToBest']):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.1,
        f"{diff:.2f}s",
        ha='center',
        va='bottom',
        fontsize=9
    )

plt.title("Najlepsze czasy okrążeń - kwalifikacje (Suzuka 2026)")
plt.xlabel("Kierowca")
plt.ylabel("Czas okrążenia (s)")
plt.xticks(rotation=45)
plt.tight_layout()

plt.show()