#!/usr/bin/env python3
"""
Bin-picking grasp-success experiment in PyBullet (headless / DIRECT).

A floating parallel-jaw gripper (two friction fingers driven by position
constraints) attempts top-down grasps on objects dropped into a tray, across
increasing clutter. The jaws always close along the world x-axis, centered on
the target object's base position; the grasp is not planned from the object's
shape or orientation. We measure grasp success rate (GSR) per clutter level —
the canonical "does it survive a pile?" question that decides whether a picker
ships.

Original, self-contained simulation: object set comes from pybullet_data plus
parametric primitives, so anyone can reproduce it. Numbers are OUR measured GSR
in this setup — not a reproduction of any single paper's protocol. The point is
the *trend*: how grasp reliability degrades from a single object to a pile.

Usage:  python3 bin_picking_experiment.py --trials 80 --seed 0 --out results.json
"""
import argparse, json, math, os, random, time
import numpy as np
import pybullet as p
import pybullet_data

# Pile sizes (objects dropped per trial) swept by main(), in this order.
CLUTTER_LEVELS = [1, 2, 4, 8]
LIFT_DZ = 0.10            # lift threshold (m): the target's rise during the lift is compared to this to decide success
OPEN_W = 0.14            # centre-to-centre finger spacing when open (m); inner gap is OPEN_W - 2 * FINGER[0]
FINGER = [0.006, 0.022, 0.045]   # finger box half-extents (m): thin in x (the closing direction), deep in z
Z_TABLE = 0.0            # NOTE(review): not referenced anywhere in this script


def reset_scene():
    """Wipe the physics world and rebuild the static scene.

    After the reset, gravity, the pybullet_data search path and the solver
    iteration count are configured again. Then the ground plane is loaded,
    and a tray is built from four static (mass 0) box walls around the origin.
    Each wall is 10 cm tall and rests on the ground.
    """
    p.resetSimulation()
    p.setGravity(0, 0, -9.81)
    p.setAdditionalSearchPath(pybullet_data.getDataPath())
    p.setPhysicsEngineParameter(numSolverIterations=80)
    p.loadURDF("plane.urdf")

    # shallow tray: wall half-thickness, wall half-height, half-span of the tray
    thick, half_h, span = 0.01, 0.05, 0.22
    walls = (
        ((span, thick), (0, span)),     # +y wall
        ((span, thick), (0, -span)),    # -y wall
        ((thick, span), (span, 0)),     # +x wall
        ((thick, span), (-span, 0)),    # -x wall
    )
    for (hx, hy), (px, py) in walls:
        wall_shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=[hx, hy, half_h])
        # centre at z = half-height so the wall's bottom sits on the ground
        p.createMultiBody(0, wall_shape, basePosition=[px, py, half_h])


def spawn_object(rng):
    """Drop one random object above the tray and return its PyBullet body id.

    Random draws are taken from ``rng`` in a fixed order:
      1. planar x, then y, each within +/-8 cm of the tray centre;
      2. a release height of 12-22 cm;
      3. three Euler angles in [0, 3.14];
      4. a shape kind, where boxes are twice as likely as each other kind;
      5. the shape's own size parameters.
    Spheres ignore the sampled orientation. Every object gets the same
    friction settings.
    """
    pos = [rng.uniform(-0.08, 0.08), rng.uniform(-0.08, 0.08), rng.uniform(0.12, 0.22)]
    euler = [rng.uniform(0, 3.14) for _ in range(3)]
    orn = p.getQuaternionFromEuler(euler)
    kind = rng.choice(["box", "box", "cylinder", "sphere", "duck"])
    mass = 0.08

    if kind == "duck":
        urdf_path = os.path.join(pybullet_data.getDataPath(), "duck_vhacd.urdf")
        body = p.loadURDF(urdf_path, pos, orn, globalScaling=0.7)
    elif kind == "box":
        half = [rng.uniform(0.015, 0.028) for _ in range(3)]
        shape = p.createCollisionShape(p.GEOM_BOX, halfExtents=half)
        body = p.createMultiBody(mass, shape, basePosition=pos, baseOrientation=orn)
    elif kind == "cylinder":
        # radius is drawn before height, matching the original draw order
        radius = rng.uniform(0.013, 0.025)
        height = rng.uniform(0.03, 0.06)
        shape = p.createCollisionShape(p.GEOM_CYLINDER, radius=radius, height=height)
        body = p.createMultiBody(mass, shape, basePosition=pos, baseOrientation=orn)
    else:
        shape = p.createCollisionShape(p.GEOM_SPHERE, radius=rng.uniform(0.016, 0.026))
        body = p.createMultiBody(mass, shape, basePosition=pos)

    p.changeDynamics(body, -1, lateralFriction=1.1, rollingFriction=0.001, spinningFriction=0.001)
    return body


def step(n):
    """Advance the physics simulation by ``n`` discrete steps.

    Does nothing when ``n`` is zero or negative.
    """
    for _tick in range(n):
        p.stepSimulation()


class Jaw:
    """Floating parallel-jaw gripper made of two friction box fingers.

    Each finger is a free 0.05 kg box tied to the world by a fixed
    constraint. The gripper is actuated by moving the constraints' world-frame
    targets and stepping the simulation. The fingers close along world x.
    """

    # World y of the plane both fingers move in. This is only a class-level
    # default; grasp() assigns it per instance, so separate Jaw objects never
    # share it.
    _y = 0.0

    def __init__(self):
        col = p.createCollisionShape(p.GEOM_BOX, halfExtents=FINGER)
        self.L = p.createMultiBody(0.05, col, basePosition=[-OPEN_W / 2, 0, 0.35])
        self.R = p.createMultiBody(0.05, col, basePosition=[OPEN_W / 2, 0, 0.35])
        for f in (self.L, self.R):
            p.changeDynamics(f, -1, lateralFriction=2.0, spinningFriction=0.01)
        # Fixed constraints to the world (child body -1). The child-frame
        # position is the finger's world-space target, which move() updates.
        self.cL = p.createConstraint(self.L, -1, -1, -1, p.JOINT_FIXED, [0, 0, 0], [0, 0, 0], [-OPEN_W / 2, 0, 0.35])
        self.cR = p.createConstraint(self.R, -1, -1, -1, p.JOINT_FIXED, [0, 0, 0], [0, 0, 0], [OPEN_W / 2, 0, 0.35])

    def move(self, lx, rx, z, force, n=80):
        """Retarget both fingers, then simulate ``n`` steps.

        Args:
            lx: target world x of the left finger centre (m).
            rx: target world x of the right finger centre (m).
            z: target world z of both finger centres (m).
            force: ``maxForce`` applied to both constraints.
            n: number of simulation steps to run after retargeting.
        """
        p.changeConstraint(self.cL, [lx, self._y, z], maxForce=force)
        p.changeConstraint(self.cR, [rx, self._y, z], maxForce=force)
        step(n)

    def grasp(self, body):
        """Attempt one top-down grasp of ``body`` and report whether it lifted.

        The sequence is:
          1. open the jaws above the target;
          2. descend to straddle it;
          3. squeeze toward the target's x;
          4. lift to z = 0.45 m.

        Args:
            body: PyBullet body id of the target object.

        Returns:
            True if the target's z rose by at least ``LIFT_DZ`` between the
            end of the squeeze and 40 steps after the lift, else False.
        """
        cx, cy, cz = p.getBasePositionAndOrientation(body)[0]
        # Per-instance state: writing Jaw._y here would leak this grasp's y
        # into every other Jaw.
        self._y = cy
        # Finger centres go to the target's centre height. They are clamped to
        # at least FINGER[2] + 3 mm, so the finger bottoms stay just above the
        # floor when the target sits low.
        zg = max(cz, FINGER[2] + 0.003)
        self.move(cx - OPEN_W / 2, cx + OPEN_W / 2, 0.34, 400, 60)   # open, above
        self.move(cx - OPEN_W / 2, cx + OPEN_W / 2, zg, 300, 100)    # descend, straddle
        # The squeeze targets are 8 mm apart centre-to-centre, but each finger
        # is 2 * FINGER[0] = 12 mm thick. The fingers therefore never reach
        # their targets and keep pressing, limited by maxForce.
        self.move(cx - 0.004, cx + 0.004, zg, 22, 90)                # squeeze toward center
        z0 = p.getBasePositionAndOrientation(body)[0][2]
        # NOTE(review): maxForce caps the whole fixed constraint, so 400 here
        # presumably also raises the horizontal squeeze cap from 22. Confirm
        # this is intended.
        self.move(cx - 0.004, cx + 0.004, 0.45, 400, 160)            # lift, hold squeeze
        step(40)
        z1 = p.getBasePositionAndOrientation(body)[0][2]
        return (z1 - z0) >= LIFT_DZ


def run_condition(n_objects, trials, rng):
    """Estimate the grasp success rate for piles of ``n_objects`` objects.

    Each trial runs these steps:
      1. rebuild the scene;
      2. drop ``n_objects`` random objects;
      3. let the pile settle for 280 steps;
      4. attempt one grasp on a uniformly chosen object with a fresh Jaw.

    Args:
        n_objects: objects dropped per trial; must be >= 1.
        trials: number of independent trials; must be >= 1.
        rng: random source passed to spawn_object() and used to pick the
            target. main() passes a ``random.Random``.

    Returns:
        Fraction of trials whose grasp succeeded, in [0, 1].

    Raises:
        ValueError: if ``n_objects`` or ``trials`` is less than 1. Without
            this check, these inputs would fail later with an opaque
            IndexError from rng.choice([]) or a ZeroDivisionError.
    """
    if n_objects < 1:
        raise ValueError(f"n_objects must be >= 1, got {n_objects}")
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")

    succ = 0
    for _ in range(trials):
        reset_scene()
        bodies = [spawn_object(rng) for _ in range(n_objects)]
        step(280)                                       # settle the pile
        # Grasp a RANDOMLY chosen object, not the easiest top pick. This is the
        # honest declutter question: any given object may be blocked by neighbors.
        target = rng.choice(bodies)
        if Jaw().grasp(target):
            succ += 1
    return succ / trials


def main():
    """Run the GSR-vs-clutter sweep and write a JSON summary.

    Command-line options:
      --trials  grasp attempts per clutter level (default 80; must be > 0).
      --seed    seeds the experiment's ``random.Random`` stream and NumPy's
                global RNG (default 0).
      --out     path of the JSON results file (default results.json).

    The RNG stream is shared across clutter levels in CLUTTER_LEVELS order.
    Each level's result therefore also depends on the levels run before it.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--trials", type=int, default=80)
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out", default="results.json")
    args = ap.parse_args()
    if args.trials <= 0:
        # Fail at the CLI boundary instead of deep inside the sweep.
        ap.error("--trials must be a positive integer")

    # p.connect returns -1 when it cannot connect.
    if p.connect(p.DIRECT) < 0:
        raise RuntimeError("could not connect to the PyBullet physics server")
    rng = random.Random(args.seed)
    np.random.seed(args.seed)

    results, t0 = {}, time.time()
    try:
        for n in CLUTTER_LEVELS:
            gsr = run_condition(n, args.trials, rng)
            results[n] = round(gsr, 4)
            print(f"clutter={n:>2}  trials={args.trials}  GSR={gsr*100:5.1f}%", flush=True)
    finally:
        # Release the physics server even if a trial raises.
        p.disconnect()

    out = {
        "experiment": "pybullet_bin_picking_gsr_vs_clutter",
        "gripper": "parallel-jaw, two friction fingers (floating)",
        "object_set": ["box", "cylinder", "sphere", "duck_vhacd"],
        "trials_per_condition": args.trials,
        "seed": args.seed,
        "lift_success_dz_m": LIFT_DZ,
        "clutter_levels": CLUTTER_LEVELS,
        # int keys are written as JSON strings ("1", "2", ...) by json.dump
        "gsr_by_clutter": results,
        "wall_time_s": round(time.time() - t0, 1),
    }
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print("wrote", args.out)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
