#!/usr/bin/env python3
"""
ZeroKVault — Offline Public Key Generator
==========================================
Derives the same EC P-256 public key that ZeroKVault generates in your browser
from a 12-word BIP-39 seed phrase, on a trusted offline machine.

Requirements
------------
  pip install cryptography

Usage
-----
  python3 generate_key.py                         # prompts for words interactively
  python3 generate_key.py word1 word2 ... word12  # words as command-line arguments

Output
------
  Public key (JWK)  — paste this into ZeroKVault's "Import Public Key" form
  Fingerprint       — verify this matches what appears in your ZeroKVault key list

Security note
-------------
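  This script prints the public key only. The private key is derived in memory
  but never displayed or stored, and your 12 secret words never leave this
  machine. Run this on a trusted, preferably air-gapped device, and prefer the
  interactive prompt: words passed as command-line arguments can end up in
  your shell history.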
"""

import hashlib
import json
import sys

# ── Python version check ───────────────────────────────────────────────────────

# Refuse to run on interpreters older than 3.8: report the requirement on
# stderr and stop with a non-zero exit status before anything else executes.
if sys.version_info < (3, 8):
    print("Error: Python 3.8 or newer is required.", file=sys.stderr)
    sys.exit(1)

# ── Dependency check ───────────────────────────────────────────────────────────

try:
    from cryptography.hazmat.primitives import hashes
    from cryptography.hazmat.primitives.asymmetric import ec
    from cryptography.hazmat.primitives.kdf.hkdf import HKDF
except ImportError:
    print(
        "Error: 'cryptography' package not found.\n"
        "Install it with:  pip install cryptography",
        file=sys.stderr,
    )
    sys.exit(1)

import base64


# ── Core functions ─────────────────────────────────────────────────────────────


def _base64url_encode(data: bytes) -> str:
    """Return *data* in URL-safe base64 text form with trailing "=" padding removed."""
    padded = base64.urlsafe_b64encode(data)
    # The encoder only ever emits ASCII, so decoding first and stripping the
    # padding from the text is equivalent to stripping it from the bytes.
    return padded.decode("ascii").rstrip("=")


def derive_public_key(mnemonic: str) -> dict:
    """
    Turn a BIP-39 mnemonic into an EC P-256 public key in JWK form.

    The mnemonic is used exactly as given (UTF-8 encoded); any normalisation
    is the caller's responsibility. The stages are intended to mirror
    src/lib/crypto/derive-key.ts:
      1. BIP-39 seed  : PBKDF2-HMAC-SHA512(mnemonic, b"mnemonic", 2048 iter, 64 bytes)
      2. HKDF-SHA256  : (seed, salt=32 zero bytes, info=b"zerokvault-ecdh-p256-v1", 32 bytes)
      3. P-256 scalar : the 32-byte HKDF output read as a big-endian integer d
      4. Public key   : the point for d, exported as JWK {kty, crv, x, y}

    Returns a dict with keys "kty", "crv", "x", "y" (in that order); x and y
    are the 32-byte big-endian coordinates, base64url-encoded without padding.
    """
    # Stage 1: standard BIP-39 seed stretch with an empty passphrase, which is
    # why the PBKDF2 salt is the bare string "mnemonic".
    bip39_seed = hashlib.pbkdf2_hmac(
        "sha512", mnemonic.encode("utf-8"), b"mnemonic", 2048, dklen=64
    )

    # Stage 2: compress the 64-byte seed to 32 bytes of key material.
    hkdf = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=bytes(32),  # 32 zero bytes, matching derive-key.ts
        info=b"zerokvault-ecdh-p256-v1",
    )
    scalar = int.from_bytes(hkdf.derive(bip39_seed), "big")

    # Stage 3: interpret the scalar as a P-256 private key and take its point.
    point = ec.derive_private_key(scalar, ec.SECP256R1()).public_key().public_numbers()

    # Stage 4: assemble the JWK; each coordinate is fixed-width (32 bytes).
    jwk = {"kty": "EC", "crv": "P-256"}
    for name, coordinate in (("x", point.x), ("y", point.y)):
        jwk[name] = _base64url_encode(coordinate.to_bytes(32, "big"))
    return jwk


def compute_fingerprint(jwk: dict) -> str:
    """
    Return the short fingerprint ZeroKVault shows for a public key.

    Intended to match fingerprintPublicKey() in src/lib/crypto/derive-key.ts:
    the "crv", "kty", "x" and "y" members (in that order, any other members
    ignored) are serialised as compact JSON, hashed with SHA-256, and the
    first 16 hex characters of the digest are returned.

    Raises KeyError if any of the four members is missing from *jwk*.
    """
    # Build the canonical object in fixed key order; dict insertion order is
    # what json.dumps emits, so the serialisation is deterministic.
    canonical_fields = {member: jwk[member] for member in ("crv", "kty", "x", "y")}
    payload = json.dumps(canonical_fields, separators=(",", ":")).encode("utf-8")
    digest_hex = hashlib.sha256(payload).hexdigest()
    return digest_hex[:16]


# ── Entry point ────────────────────────────────────────────────────────────────


def main() -> None:
    """
    Command-line entry point: read a 12-word phrase, print its JWK and fingerprint.

    The words are taken from the command-line arguments when any are given,
    otherwise from an interactive prompt. Input is lowercased and runs of
    whitespace are collapsed to single spaces before derivation. Only the word
    count is validated here; the words are not checked against the BIP-39
    wordlist or checksum, so always compare the printed fingerprint afterwards.

    Exits with status 1 (message on stderr) when no input can be read, when the
    word count is not 12, or when key derivation rejects the input.
    """
    if len(sys.argv) > 1:
        mnemonic = " ".join(sys.argv[1:])
        # Command-line arguments are commonly saved in shell history and shown
        # in process listings; warn on stderr so stdout stays unchanged.
        print(
            "Warning: secret words passed as arguments may be saved in your shell "
            "history. Prefer running without arguments and typing them at the prompt.",
            file=sys.stderr,
        )
    else:
        print("Enter your 12 secret words separated by spaces:")
        try:
            mnemonic = input("> ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C: stop cleanly instead of dumping a traceback.
            print("\nError: no words were entered.", file=sys.stderr)
            sys.exit(1)

    # Normalize: lowercase, collapse whitespace (split() also drops leading
    # and trailing whitespace, so no separate strip is needed).
    words = mnemonic.lower().split()
    if len(words) != 12:
        print(
            f"Error: expected exactly 12 words, got {len(words)}.",
            file=sys.stderr,
        )
        sys.exit(1)
    mnemonic = " ".join(words)

    try:
        jwk = derive_public_key(mnemonic)
    except ValueError:
        # NOTE(review): presumably only reachable when the derived scalar is
        # not a valid P-256 private key (zero or not below the curve order),
        # which is astronomically unlikely — confirm against the library docs.
        # The exception text is deliberately not echoed.
        print(
            "Error: these words do not produce a valid P-256 key.",
            file=sys.stderr,
        )
        sys.exit(1)
    fingerprint = compute_fingerprint(jwk)

    print()
    print("── Public Key (JWK) " + "─" * 52)
    print(json.dumps(jwk, indent=2))
    print()
    print("── Fingerprint " + "─" * 57)
    print(fingerprint)
    print()
    print("Next steps:")
    print("  1. Copy the Public Key (JWK) above.")
    print("  2. In ZeroKVault, go to Keys → Add Key → Import tab.")
    print("  3. Paste the JWK and give the key a name.")
    print("  4. After import, verify the fingerprint matches what appears in your key list.")

# Run the command-line interface only when this file is executed as a script,
# so importing the module does not trigger the prompt.
if __name__ == "__main__":
    main()
