from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, field_validator
from typing import List, Optional
import pandas as pd
import pickle
import os

# === Model artifact paths ===
# All artifacts live in a "model" directory next to this file.
HERE = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(HERE, "model")
MODEL_PATH, LE_JK_PATH, LE_STATUS_PATH = (
    os.path.join(MODEL_DIR, artifact_name)
    for artifact_name in ("svm.pkl", "le_jenis_kelamin.pkl", "le_status_gizi.pkl")
)

# Column order used to build the DataFrame passed to the model;
# presumably this must match the order used at training time — confirm.
FEATURE_ORDER = [
    "tinggi",
    "berat",
    "umur_bulan",
    "jenis_kelamin_encoded",
]

# Accepted spellings of the gender field, mapped to the codes "L" (male) / "P" (female).
GENDER_ALIASES = {
    "L": "L",
    "P": "P",
    "l": "L",
    "p": "P",
    "Male": "L",
    "Female": "P",
    "M": "L",
    "F": "P",
}

# === Input uses camelCase field names ===
class Item(BaseModel):
    """A single prediction request: a child's measurements plus gender.

    The camelCase payload fields are mapped to the model's feature names by
    ``to_features``.
    """

    tinggiCm: float  # height in cm, per the field name
    beratKg: float  # weight in kg, per the field name
    usiaBulan: int  # age in months, per the field name
    jenisKelamin: str  # gender; normalised to "L"/"P" by normalisasi_jk
    # Optional pre-encoded gender; when set, the label encoder is bypassed.
    jenisKelaminEncoded: Optional[int] = None

    @field_validator("jenisKelamin")
    @classmethod
    def normalisasi_jk(cls, v):
        """Normalise a gender string to the encoder's code ("L" or "P").

        Matching is case-insensitive against ``GENDER_ALIASES`` plus the full
        Indonesian words. An unknown value falls back to its first letter,
        which is itself mapped through the aliases. For example, ``"male"``
        becomes ``"L"`` rather than ``"M"``. Codes the encoder still doesn't
        know are passed through and rejected later by ``to_features``.

        Raises:
            ValueError: if the value is empty after stripping whitespace
                (pydantic reports this as a validation error instead of an
                unhandled IndexError).
        """
        if v is None:
            return v
        v = str(v).strip()
        if not v:
            raise ValueError("jenisKelamin tidak boleh kosong")
        aliases = {key.upper(): code for key, code in GENDER_ALIASES.items()}
        # Full Indonesian words. "PRIA" (male) must not fall through to its
        # first letter "P", which is the female code.
        aliases.update({"LAKI-LAKI": "L", "PRIA": "L", "PEREMPUAN": "P", "WANITA": "P"})
        key = v.upper()
        if key in aliases:
            return aliases[key]
        first = key[0]
        return aliases.get(first, first)

    def to_features(self, le_jk):
        """Build the feature dict consumed by the model.

        Args:
            le_jk: fitted gender label encoder exposing ``transform``. It is
                only used when ``jenisKelaminEncoded`` is not provided.

        Returns:
            dict with keys ``tinggi``, ``berat``, ``umur_bulan`` and
            ``jenis_kelamin_encoded``.

        Raises:
            ValueError: if the gender cannot be encoded. The original error
                is chained as ``__cause__``.
        """
        try:
            if self.jenisKelaminEncoded is not None:
                jk_encoded = int(self.jenisKelaminEncoded)
            else:
                jk_encoded = int(le_jk.transform([self.jenisKelamin])[0])
        except Exception as exc:
            raise ValueError(
                f"Jenis kelamin '{self.jenisKelamin}' tidak dikenali encoder"
            ) from exc

        return {
            "tinggi": float(self.tinggiCm),
            "berat": float(self.beratKg),
            "umur_bulan": int(self.usiaBulan),
            "jenis_kelamin_encoded": jk_encoded,
        }

class BatchRequest(BaseModel):
    # Batch payload: a JSON object whose "data" field is a list of Item records.
    data: List[Item]

# === Fungsi bantu ===
def load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    trusted, locally produced artifacts.

    Raises:
        FileNotFoundError: if ``os.path.exists(path)`` is false. The message
            is ``"Tidak ditemukan: <path>"``.
    """
    if os.path.exists(path):
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
        return loaded
    raise FileNotFoundError(f"Tidak ditemukan: {path}")

# === Load model & encoders ===
# Loading is best-effort: if any artifact fails, the app still starts, all
# three artifacts are None, and the exception is kept in _load_error so that
# /health and the predict endpoints can report it.
_load_error = None
try:
    model, le_jk, le_status = (
        load_pickle(MODEL_PATH),
        load_pickle(LE_JK_PATH),
        load_pickle(LE_STATUS_PATH),
    )
except Exception as e:
    model = le_jk = le_status = None
    _load_error = e

app = FastAPI(
    title="API Prediksi Status Gizi (camelCase)",
    description="API untuk prediksi status gizi anak",
    version="1.0.0",
)

@app.get("/")
def root():
    """Return a fixed message confirming the service is running."""
    payload = {"message": "API Prediksi Status Gizi", "status": "running"}
    return payload

@app.get("/health")
def health():
    """Report whether the model artifacts loaded at startup.

    Returns ``{"status": "ok"}`` when loading succeeded. Otherwise returns
    ``{"status": "error", "detail": <load error text>}``.
    """
    if _load_error is not None:
        return {"status": "error", "detail": str(_load_error)}
    return {"status": "ok"}

@app.post("/predict")
def predict(item: Item):
    """Predict the nutritional status for a single child.

    Returns ``{"label": <decoded status>, "labelEncoded": <encoded class>}``.

    Raises:
        HTTPException(500): the model or an encoder failed to load at startup.
        HTTPException(400): any error while encoding, predicting or decoding.
            The exception text is returned as ``detail``.
    """
    artifacts_ready = all(obj is not None for obj in (model, le_jk, le_status))
    if not artifacts_ready:
        raise HTTPException(status_code=500, detail=f"Gagal memuat model: {_load_error}")
    try:
        frame = pd.DataFrame([item.to_features(le_jk)])
        prediction = model.predict(frame[FEATURE_ORDER])
        encoded = int(prediction[0])
        decoded = le_status.inverse_transform([encoded])
        return {"label": str(decoded[0]), "labelEncoded": encoded}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))

@app.post("/predictBatch")
def predict_batch(req: BatchRequest):
    """Predict the nutritional status for every item in a batch.

    Returns ``{"results": [{"label": ..., "labelEncoded": ...}, ...]}`` with
    one entry per input item, in input order. An empty ``data`` list returns
    ``{"results": []}``.

    Raises:
        HTTPException(500): the model or an encoder failed to load at startup.
        HTTPException(400): any error while encoding, predicting or decoding.
            Feature-encoding errors are prefixed with the offending ``data[i]``.
    """
    if any(x is None for x in (model, le_jk, le_status)):
        raise HTTPException(status_code=500, detail=f"Gagal memuat model: {_load_error}")
    if not req.data:
        # pd.DataFrame([]) has no columns, so selecting FEATURE_ORDER would
        # raise a KeyError and reject an empty batch with a cryptic 400.
        return {"results": []}
    try:
        rows = []
        for i, item in enumerate(req.data):
            try:
                rows.append(item.to_features(le_jk))
            except ValueError as exc:
                # Tell the client which batch element was rejected.
                raise ValueError(f"data[{i}]: {exc}") from exc
        X = pd.DataFrame(rows)[FEATURE_ORDER]
        y_enc = model.predict(X)
        y_label = le_status.inverse_transform(y_enc)
        results = [
            {"label": str(lbl), "labelEncoded": int(enc)}
            for lbl, enc in zip(y_label, y_enc)
        ]
        return {"results": results}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
