文件预览

classify.py

查看 Alibabacloud Emr Starrocks Assistant 技能包中的文件内容。

文件内容

scripts/sr_connect/classify.py

"""Classify SQL statements as READ vs non-READ for the srsql execution gate.

Uses sqlglot with dialect="starrocks". Conservative: only inspects AST root
node class names; does not depend on deep parsing accuracy.

When sqlglot cannot parse a statement (StarRocks-specific syntax it doesn't
model yet, or genuine syntax errors), falls back to a leading-keyword match.
The fallback is best-effort and marked with a warning so the gate / Skill
can surface it to the user.

Verdict semantics for the gate:
    READ        -> executes directly
    everything else (incl. UNKNOWN) -> requires --yes
"""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum

import sqlglot
import sqlglot.errors
import sqlglot.expressions as exp


# sqlglot dialect name passed to both sqlglot.parse() and Expression.sql()
# throughout this module.
DIALECT = "starrocks"


class Verdict(str, Enum):
    """Category assigned to a classified SQL statement.

    Per the module docstring, only READ executes directly; every other
    member (including UNKNOWN) requires --yes at the gate.
    """

    READ = "read"
    WRITE_DML = "write_dml"   # INSERT / UPDATE / DELETE / MERGE / LOAD
    DDL = "ddl"               # CREATE / ALTER / DROP / TRUNCATE
    ADMIN = "admin"           # GRANT / REVOKE / CREATE USER / SET CONFIG ...
    SESSION = "session"       # SET / USE / COMMIT / ROLLBACK
    UNKNOWN = "unknown"       # parse failure and no keyword match


# Maps the class name of a parsed statement's root AST node to its Verdict.
# Only the identity of the root node is consulted, never anything deeper in
# the tree. Root classes missing from this table are handled by the caller.
_ROOT_CLASS_TO_VERDICT: dict[str, Verdict] = {
    # Read
    **dict.fromkeys(
        (
            "Select",
            "Union",
            "Intersect",
            "Except",
            "With",
            "Subquery",
            "Show",
            "Describe",
            "Pragma",
        ),
        Verdict.READ,
    ),
    # Write (DML)
    **dict.fromkeys(
        ("Insert", "Update", "Delete", "Merge"),
        Verdict.WRITE_DML,
    ),
    # DDL
    **dict.fromkeys(
        ("Create", "Drop", "Alter", "AlterColumn", "TruncateTable"),
        Verdict.DDL,
    ),
    # Admin (sqlglot has dedicated classes for some)
    "Kill": Verdict.ADMIN,
    # Session-scoped
    **dict.fromkeys(
        ("Set", "Use", "Transaction", "Commit", "Rollback"),
        Verdict.SESSION,
    ),
}


# Leading-keyword prefix -> Verdict, consulted in two situations:
#   (a) the root node is exp.Command (sqlglot's catch-all for statements it
#       does not model), and
#   (b) the parse-failure fallback, when sqlglot could not parse at all.
# The list is scanned top to bottom and the FIRST matching prefix wins, so
# ordering is critical: a longer / more-specific prefix (e.g. "CREATE USER")
# MUST appear before the shorter prefix it starts with (e.g. "CREATE").
# Prefixes are upper-case with single spaces between words.
_COMMAND_VERBS: list[tuple[str, Verdict]] = [
    # Two-word: user/role admin
    ("CREATE USER", Verdict.ADMIN),
    ("CREATE ROLE", Verdict.ADMIN),
    ("DROP USER", Verdict.ADMIN),
    ("DROP ROLE", Verdict.ADMIN),
    ("ALTER USER", Verdict.ADMIN),
    ("ALTER ROLE", Verdict.ADMIN),
    ("SET PASSWORD", Verdict.ADMIN),
    ("SET DEFAULT", Verdict.ADMIN),       # SET DEFAULT ROLE
    ("SET ROLE", Verdict.SESSION),
    # Two-word: write/refresh
    ("SUBMIT TASK", Verdict.WRITE_DML),
    ("REFRESH MATERIALIZED", Verdict.WRITE_DML),
    # Single-word admin
    ("GRANT", Verdict.ADMIN),
    ("REVOKE", Verdict.ADMIN),
    ("ADMIN", Verdict.ADMIN),             # ADMIN SET CONFIG / REPAIR / ...
    ("CANCEL", Verdict.ADMIN),
    ("PAUSE", Verdict.ADMIN),
    ("RESUME", Verdict.ADMIN),
    ("STOP", Verdict.ADMIN),
    ("KILL", Verdict.ADMIN),
    ("ANALYZE", Verdict.ADMIN),
    ("BACKUP", Verdict.ADMIN),
    ("RESTORE", Verdict.ADMIN),
    ("RECOVER", Verdict.ADMIN),
    ("INSTALL", Verdict.ADMIN),
    ("UNINSTALL", Verdict.ADMIN),
    # Write (job-level commands)
    ("LOAD", Verdict.WRITE_DML),
    ("REFRESH", Verdict.WRITE_DML),
    # Read
    ("EXPLAIN", Verdict.READ),
    ("DESCRIBE", Verdict.READ),
    ("DESC", Verdict.READ),
    ("HELP", Verdict.READ),
    ("SHOW", Verdict.READ),
    ("WITH", Verdict.READ),
    ("SELECT", Verdict.READ),
    # Session
    ("USE", Verdict.SESSION),
    ("SET", Verdict.SESSION),
    # DDL fallback (must come AFTER more-specific user/role variants above)
    ("CREATE", Verdict.DDL),
    ("DROP", Verdict.DDL),
    ("ALTER", Verdict.DDL),
    ("TRUNCATE", Verdict.DDL),
    ("RENAME", Verdict.DDL),
    # DML fallback
    ("INSERT", Verdict.WRITE_DML),
    ("UPDATE", Verdict.WRITE_DML),
    ("DELETE", Verdict.WRITE_DML),
    ("MERGE", Verdict.WRITE_DML),
]


# Human-readable statement label for each AST root class name; used as the
# `statement_type` shown to the user.
_TYPE_LABEL: dict[str, str] = dict(
    # Read
    Select="SELECT",
    Union="UNION",
    Intersect="INTERSECT",
    Except="EXCEPT",
    With="WITH (CTE)",
    Subquery="SUBQUERY",
    Show="SHOW",
    Describe="DESCRIBE",
    Pragma="PRAGMA",
    # Write (DML)
    Insert="INSERT",
    Update="UPDATE",
    Delete="DELETE",
    Merge="MERGE",
    # DDL
    Create="CREATE",
    Drop="DROP",
    Alter="ALTER",
    AlterColumn="ALTER COLUMN",
    TruncateTable="TRUNCATE TABLE",
    # Admin
    Kill="KILL",
    # Session-scoped
    Set="SET",
    Use="USE",
    Transaction="TRANSACTION",
    Commit="COMMIT",
    Rollback="ROLLBACK",
)


@dataclass(frozen=True)
class Classification:
    """Immutable outcome of classifying one SQL statement (or, when built by
    `classify_one`, an aggregated multi-statement batch)."""

    # Category assigned to the statement.
    verdict: Verdict
    # Display label, e.g. "SELECT", "CREATE USER", "EMPTY".
    statement_type: str
    # Best-effort name of the object the statement acts on, when available.
    target: str | None = None
    # Populated when the classification was best-effort or could not be made.
    warning: str | None = None

    @property
    def is_read_only(self) -> bool:
        """True only when the verdict is READ."""
        return self.verdict == Verdict.READ


def classify(sql: str) -> list[Classification]:
    """Classify each top-level statement in `sql`. Multi-statement supported.

    Args:
        sql: Raw SQL text; may hold several statements.

    Returns:
        A non-empty list of Classification objects. When sqlglot parses the
        input there is one entry per statement it reports, in order. When
        sqlglot raises, or reports no statements, the list holds a single
        best-effort entry produced by the leading-keyword fallback. Blank
        input yields a single UNKNOWN / "EMPTY" entry.
    """
    if not sql or not sql.strip():
        return [Classification(
            verdict=Verdict.UNKNOWN,
            statement_type="EMPTY",
            warning="no SQL statement detected",
        )]
    try:
        asts = sqlglot.parse(sql, dialect=DIALECT)
    except sqlglot.errors.SqlglotError as e:
        # Catch the common base class rather than ParseError alone: the
        # tokenizer reports problems such as an unterminated string literal
        # with TokenError, which is a sibling of ParseError. Either failure
        # must degrade to the keyword fallback instead of crashing the gate.
        return [_fallback_keyword_classify(sql, parse_error=str(e))]
    if not asts:
        return [_fallback_keyword_classify(sql, parse_error="no statement detected (comments only?)")]
    results = []
    for ast in asts:
        if ast is None:
            # sqlglot produced no tree for this statement; fall back to the
            # leading keyword of the whole input.
            results.append(_fallback_keyword_classify(sql, parse_error="statement could not be parsed"))
        else:
            results.append(_classify_node(ast))
    return results


def classify_one(sql: str) -> Classification:
    """Reduce the per-statement results of `classify` to a single decision.

    A lone statement is returned unchanged. For a batch, the first statement
    that is not READ decides the verdict (any non-READ wins over READ); a
    batch made up solely of READ statements yields a READ verdict.
    """
    per_statement = classify(sql)
    count = len(per_statement)

    if count == 0:
        return Classification(
            verdict=Verdict.UNKNOWN,
            statement_type="EMPTY",
            warning="no SQL statement detected",
        )
    if count == 1:
        return per_statement[0]

    # First entry in statement order whose verdict is anything but READ.
    blocker = next(
        (item for item in per_statement if not item.is_read_only),
        None,
    )
    if blocker is None:
        return Classification(
            verdict=Verdict.READ,
            statement_type=f"MULTI READ ({count} stmts)",
        )
    return Classification(
        verdict=blocker.verdict,
        statement_type=f"MIXED ({count} stmts, first non-READ: {blocker.statement_type})",
        target=blocker.target,
        warning=f"batch contains {count} statements; gating on first non-READ",
    )


def _classify_node(ast: exp.Expression) -> Classification:
    """Classify one parsed statement from its root AST node.

    Resolution order:
      1. Root class name found in `_ROOT_CLASS_TO_VERDICT` -> that verdict,
         with a display label and a best-effort target.
      2. Root is `exp.Command` (sqlglot's catch-all) -> first matching prefix
         in `_COMMAND_VERBS`, tested against "<verb> <payload>".
      3. Anything else -> UNKNOWN with a warning.

    Args:
        ast: Root expression of a single parsed statement.

    Returns:
        The Classification for that statement.
    """
    cls_name = type(ast).__name__
    verdict = _ROOT_CLASS_TO_VERDICT.get(cls_name)
    if verdict is not None:
        return Classification(
            verdict=verdict,
            statement_type=_TYPE_LABEL.get(cls_name, cls_name.upper()),
            target=_extract_target(ast),
        )
    if isinstance(ast, exp.Command):
        verb = (ast.this or "").upper().strip()
        payload = ast.expression
        if isinstance(payload, str):
            # The payload can be the raw remainder of the statement as a
            # plain string; it has no `.name`, so use it directly. Otherwise
            # two-word prefixes such as "ALTER USER" could never match.
            rest = payload
        elif payload is not None:
            rest = getattr(payload, "name", "") or ""
        else:
            rest = ""
        # Collapse every whitespace run to one space so that the two-word
        # prefixes match regardless of how the original SQL was spaced.
        full = " ".join(f"{verb} {rest}".upper().split())
        for prefix, v in _COMMAND_VERBS:
            if full.startswith(prefix):
                return Classification(verdict=v, statement_type=prefix)
        return Classification(
            verdict=Verdict.UNKNOWN,
            statement_type=verb or "UNKNOWN",
            warning=f"unrecognized statement verb: {verb!r}",
        )
    return Classification(
        verdict=Verdict.UNKNOWN,
        statement_type=cls_name.upper(),
        warning=f"unrecognized AST root: {cls_name}",
    )


def _fallback_keyword_classify(sql: str, parse_error: str) -> Classification:
    """Best-effort classification by leading keyword when sqlglot fails.

    The leading words of `sql` (comments skipped) are compared against
    `_COMMAND_VERBS`; the first prefix that matches on a word boundary wins.

    Because the input was never parsed, a READ result is only trusted when
    the text shows no sign of further statements: if a ';' occurs anywhere
    before the trailing semicolons/whitespace, the verdict is UNKNOWN, so the
    gate asks for confirmation instead of letting e.g.
    "SELECT <unparseable>; DROP TABLE t" through as a read.

    Args:
        sql: The SQL text that could not be parsed.
        parse_error: Description of the failure; included (truncated to 120
            characters) in the returned warning.

    Returns:
        A Classification that always carries a warning.
    """
    leader = _leading_keyword(sql)
    if not leader:
        return Classification(
            verdict=Verdict.UNKNOWN,
            statement_type="EMPTY",
            warning=parse_error,
        )
    reason = parse_error[:120]
    # A ';' that is not part of the trailing terminator may separate
    # statements. It may also sit inside a string literal or comment, but
    # without a parse the conservative reading is the safe one.
    may_have_more_statements = ";" in sql.rstrip("; \t\r\n")
    for prefix, v in _COMMAND_VERBS:
        if not leader.startswith(prefix):
            continue
        # Require a word boundary after the prefix so that e.g. "SELECTED"
        # or "USER" is not taken for "SELECT" / "USE".
        following = leader[len(prefix):len(prefix) + 1]
        if following.isalnum() or following == "_":
            continue
        if v == Verdict.READ and may_have_more_statements:
            return Classification(
                verdict=Verdict.UNKNOWN,
                statement_type=prefix,
                warning=(
                    f"sqlglot could not parse and the input contains ';' "
                    f"before its end, so further statements may follow; "
                    f"leading keyword {prefix!r} alone is not trusted as "
                    f"READ. Reason: {reason}"
                ),
            )
        return Classification(
            verdict=v,
            statement_type=prefix,
            warning=(
                f"sqlglot could not parse; classified by leading keyword "
                f"(best-effort). Reason: {reason}"
            ),
        )
    one_word = leader.split()[0]
    return Classification(
        verdict=Verdict.UNKNOWN,
        statement_type=one_word,
        warning=f"unrecognized leading keyword {one_word!r}; parse error: {reason}",
    )


def _leading_keyword(sql: str) -> str:
    """Return up to the first three words of `sql`, upper-cased and joined by
    single spaces, after skipping leading whitespace and comments.

    Recognised leading comments are `-- ...` and `# ...` (to end of line) and
    `/* ... */`. A comment that is never terminated consumes the rest of the
    input. Returns "" when nothing but whitespace/comments is present.
    """
    text = sql.lstrip()
    while text:
        if text.startswith(("--", "#")):
            terminator = "\n"
        elif text.startswith("/*"):
            terminator = "*/"
        else:
            break
        # partition() yields an empty tail when the terminator is absent,
        # i.e. an unterminated comment swallows everything that is left.
        _, _, text = text.partition(terminator)
        text = text.lstrip()
    words = text.split(None, 3)[:3]
    return " ".join(words).upper()


def _extract_target(ast: exp.Expression) -> str | None:
    """Return a display name for the object a statement acts on, or None.

    Looks only at the node's `this` attribute. A table (directly, or wrapped
    in a schema node) is rendered in the module dialect. Any other node is
    rendered too, but kept only if the text is short (at most 100 characters)
    and fits on one line. Best-effort: rendering failures yield None.
    """
    node = getattr(ast, "this", None)
    if node is None:
        return None

    # Unwrap a schema node that directly wraps a table.
    table = node
    if isinstance(node, exp.Schema) and isinstance(node.this, exp.Table):
        table = node.this
    if isinstance(table, exp.Table):
        return table.sql(dialect=DIALECT)

    if not hasattr(node, "sql"):
        return None
    try:
        rendered = node.sql(dialect=DIALECT)
        if len(rendered) <= 100 and "\n" not in rendered:
            return rendered
    except Exception:
        # Deliberate best-effort: the target is display-only, so a rendering
        # problem must not affect classification.
        pass
    return None