import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel

MODEL = "faisalq/EgyBERT"
device = torch.device("cpu")

# Load the tokenizer and model once at startup; eval() disables dropout
# so inference is deterministic.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL).to(device)
model.eval()

app = FastAPI(title="EgyBERT Embeddings")

class EmbedRequest(BaseModel):
    texts: list[str]        # one or more strings to embed
    max_length: int = 256   # truncation length, in tokens
    normalize: bool = True  # L2-normalize so dot product equals cosine similarity

def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Masked mean pooling: average token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against division by zero on all-pad rows
    return summed / counts
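
# Illustrative shapes, assuming EgyBERT uses a BERT-base hidden size of 768:
# for 2 texts padded to 8 tokens, last_hidden_state is (2, 8, 768) and
# attention_mask is (2, 8); mean_pool returns a (2, 768) batch of sentence
# vectors with padding tokens excluded from each average.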

@app.post("/embed")
def embed(req: EmbedRequest):
    inputs = tokenizer(
        req.texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=req.max_length
    )

    with torch.no_grad():
        outputs = model(**inputs)
        vec = mean_pool(outputs.last_hidden_state, inputs["attention_mask"])
        if req.normalize:
            vec = F.normalize(vec, p=2, dim=1)

    return {"dim": vec.shape[1], "vectors": vec.cpu().tolist()}
