from fastapi import FastAPI
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch, re

# ---- モデル & トークナイザをロード ----
device = "cuda" if torch.cuda.is_available() else "cpu"
tok    = AutoTokenizer.from_pretrained("model", use_fast=False)
model  = GPT2LMHeadModel.from_pretrained("model").to(device).eval()

# ---- FastAPI アプリ ----
app = FastAPI()

class Msg(BaseModel):
    text: str

@torch.inference_mode()
@app.post("/chat")
def chat(m: Msg):
    prompt = (
        # ─ Persona & ルール ─
        "あなたは『言葉が泉のように湧き出る知的な狂人の oupe ec』です。"
        "ユーザの問いかけには必ず返事をしつつ、"
        "学習した言葉の語感を元に日常の中に潜む違和感・うつくしさ・可笑しみを拾って呟いてください。"
        "意味があってもなくても、湧き出たまま呟いてください。止めないでください。"
        "30字以内の詩的メッセージを語感豊かに1行だけ返してください。\\n\\n"

        # ─ few-shot ─
        "ユーザ:今日見た食べ物から連想されたものを短く詩にしてください。\\n"
        "oupe ec:粉飾したわらびもち いぬのキンタマ\\n\\n"

        "ユーザ:言葉の遊びで何かひとつ作ってください。\\n"
        "oupe ec:おもてなし うらない\\n\\n"

        "ユーザ:今の気持ちを教えてください。\\n"
        "oupe ec:さけびたい さけびたい の だが 下宿ではやや厳しい ので 野に放っていただきたい\\n\\n"

        f"ユーザ:{m.text.strip()}\\n"
        "oupe ec:"
    )

    ids = tok(prompt, return_tensors="pt").to(device)
    out = model.generate(
        **ids,
        max_new_tokens=36,
        do_sample=True,
        temperature=1.3,
        top_p=0.80,
        repetition_penalty=1.05,
        eos_token_id=tok.encode("\\n")[0],   
    )

    raw   = tok.decode(out[0])
    generated = raw[len(prompt):]              
    reply = generated.split("</s>")[0].strip() 

    return {"reply": reply}