import os
import pickle
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from imblearn.over_sampling import SMOTE  # Handles class imbalance

# ✅ Define dataset and model paths
# Both paths are relative fragments resolved against the current working
# directory by os.path.abspath, so the script must be launched from the
# directory that contains `data/` and `backend/`.
_DATASET_PARTS = ("data", "preprocessed_dataset.csv")
_MODEL_PARTS = ("backend", "models", "anomaly_detector.pkl")
DATASET_PATH, MODEL_PATH = (
    os.path.abspath(os.path.join(*parts))
    for parts in (_DATASET_PARTS, _MODEL_PARTS)
)

# ✅ Load dataset
# Check up front so the error message names the resolved path that was
# actually looked up.
if os.path.exists(DATASET_PATH):
    data = pd.read_csv(DATASET_PATH)
else:
    raise FileNotFoundError(f"❌ Dataset not found at: {DATASET_PATH}")

print("✅ Dataset loaded successfully!")
available_columns = data.columns.tolist()
print("📝 Available Columns:", available_columns)

# ✅ Identify the target column dynamically
# Candidate names are tried in priority order; the first one present in the
# dataset's columns becomes the label column.
possible_targets = ["label", "target", "class", "is_malicious"]
for candidate in possible_targets:
    if candidate in data.columns:
        TARGET_COLUMN = candidate
        break
else:
    # for/else: only reached when the loop finished without hitting `break`,
    # i.e. none of the candidate names exist in the dataset.
    raise KeyError(f"❌ No valid target column found! Expected one of {possible_targets}, but got: {data.columns.tolist()}")

# ✅ Handle missing values (if any)
# Rows without a label cannot be used for supervised training. Imputing the
# label (e.g. the median of 0/1 classes can be 0.5) would invent a class that
# does not exist, so such rows are dropped instead of filled.
n_missing_target = int(data[TARGET_COLUMN].isna().sum())
if n_missing_target:
    print(f"⚠️ Dropping {n_missing_target} rows with a missing '{TARGET_COLUMN}' value.")
    data = data.dropna(subset=[TARGET_COLUMN])

if data.isnull().values.any():
    print("⚠️ Missing values detected! Filling with median values.")
    # numeric_only=True: since pandas 2.0, DataFrame.median() raises TypeError
    # on non-numeric columns (e.g. a string label column) instead of skipping them.
    data = data.fillna(data.median(numeric_only=True))

# ✅ Define features (X) and target (y)
X = data.drop(columns=[TARGET_COLUMN])
y = data[TARGET_COLUMN]

# Median imputation cannot fill non-numeric columns or all-NaN columns. Report
# them here; SMOTE rejects NaN inputs and would otherwise fail with a less
# descriptive message.
unfilled_columns = X.columns[X.isnull().any()].tolist()
if unfilled_columns:
    raise ValueError(f"❌ Missing values could not be imputed in feature columns: {unfilled_columns}")

# ✅ Split data into training and testing sets BEFORE oversampling
# SMOTE synthesises minority samples by interpolating between real ones.
# Resampling the full dataset first would put synthetic points (near-copies of
# training rows) into the test set and inflate the reported metrics. The test
# set must keep the real, imbalanced class distribution.
# stratify=y preserves the class ratio in both splits, so the test set still
# contains minority-class examples.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# ✅ Handle class imbalance using SMOTE — on the training split only
smote = SMOTE(random_state=42)
X_train, y_train = smote.fit_resample(X_train, y_train)

# ✅ Train a RandomForestClassifier model
# 100 trees; the fixed random_state makes repeated runs build the same forest.
model = RandomForestClassifier(random_state=42, n_estimators=100)
model.fit(X_train, y_train)

# ✅ Evaluate model performance on the test split
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)

accuracy_line = f"📊 Accuracy: {accuracy:.2f}"
print(f"\n✅ Model Training Completed!")
print(accuracy_line)
print("📄 Classification Report:\n", classification_report(y_true=y_test, y_pred=y_pred))

# ✅ Save the trained model properly
# The pickle is written to a sibling temp file and then moved into place with
# os.replace. On POSIX that swap is atomic within the same directory, so an
# interrupted or failed dump never leaves a truncated file at MODEL_PATH.
# Errors are reported and re-raised, so the script exits non-zero instead of
# appearing to succeed without a model.
# NOTE: unpickling executes arbitrary code — only load this artifact from a
# trusted location.
tmp_model_path = MODEL_PATH + ".tmp"
try:
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)  # Ensure the directory exists
    with open(tmp_model_path, "wb") as f:
        pickle.dump(model, f)
    os.replace(tmp_model_path, MODEL_PATH)
except Exception as e:
    print(f"❌ Error saving model: {e}")
    try:
        os.remove(tmp_model_path)
    except OSError:
        pass  # best-effort cleanup; the original error is re-raised below
    raise
else:
    print(f"✅ Model saved successfully at: {MODEL_PATH}")
