Bootstrap OpenBitdo clean-room SDK and reliability milestone

This commit is contained in:
2026-02-27 20:43:34 -05:00
commit d5afadf560
46 changed files with 3652 additions and 0 deletions

View File

@@ -0,0 +1,76 @@
[package]
name = "bitdo_proto"
version = "0.1.0"
edition = "2021"
license = "MIT"
build = "build.rs"
[features]
default = ["hidapi-backend"]
hidapi-backend = ["dep:hidapi"]
[dependencies]
thiserror = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
hex = { workspace = true }
hidapi = { version = "2.6", optional = true }
[build-dependencies]
csv = "1.3"
[dev-dependencies]
serde_json = { workspace = true }
hex = { workspace = true }
[[test]]
name = "frame_roundtrip"
path = "../../tests/frame_roundtrip.rs"
[[test]]
name = "parser_rejection"
path = "../../tests/parser_rejection.rs"
[[test]]
name = "retry_timeout"
path = "../../tests/retry_timeout.rs"
[[test]]
name = "pid_matrix_coverage"
path = "../../tests/pid_matrix_coverage.rs"
[[test]]
name = "capability_gating"
path = "../../tests/capability_gating.rs"
[[test]]
name = "profile_serialization"
path = "../../tests/profile_serialization.rs"
[[test]]
name = "mode_switch_readback"
path = "../../tests/mode_switch_readback.rs"
[[test]]
name = "boot_safety"
path = "../../tests/boot_safety.rs"
[[test]]
name = "firmware_chunk"
path = "../../tests/firmware_chunk.rs"
[[test]]
name = "cleanroom_guard"
path = "../../tests/cleanroom_guard.rs"
[[test]]
name = "hardware_smoke"
path = "../../tests/hardware_smoke.rs"
[[test]]
name = "error_codes"
path = "../../tests/error_codes.rs"
[[test]]
name = "diag_probe"
path = "../../tests/diag_probe.rs"

View File

@@ -0,0 +1,121 @@
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
fn main() {
let manifest_dir =
PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("missing CARGO_MANIFEST_DIR"));
let spec_dir = manifest_dir.join("../../../spec");
let out_dir = PathBuf::from(env::var("OUT_DIR").expect("missing OUT_DIR"));
let pid_csv = spec_dir.join("pid_matrix.csv");
let command_csv = spec_dir.join("command_matrix.csv");
println!("cargo:rerun-if-changed={}", pid_csv.display());
println!("cargo:rerun-if-changed={}", command_csv.display());
generate_pid_registry(&pid_csv, &out_dir.join("generated_pid_registry.rs"));
generate_command_registry(&command_csv, &out_dir.join("generated_command_registry.rs"));
}
fn generate_pid_registry(csv_path: &Path, out_path: &Path) {
let mut rdr = csv::Reader::from_path(csv_path).expect("failed to open pid_matrix.csv");
let mut out = String::new();
out.push_str("pub const PID_REGISTRY: &[crate::registry::PidRegistryRow] = &[\n");
for rec in rdr.records() {
let rec = rec.expect("invalid pid csv record");
let name = rec.get(0).expect("pid_name");
let pid: u16 = rec
.get(1)
.expect("pid_decimal")
.parse()
.expect("invalid pid decimal");
let support_level = match rec.get(5).expect("support_level") {
"full" => "crate::types::SupportLevel::Full",
"detect-only" => "crate::types::SupportLevel::DetectOnly",
other => panic!("unknown support_level {other}"),
};
let protocol_family = match rec.get(6).expect("protocol_family") {
"Standard64" => "crate::types::ProtocolFamily::Standard64",
"JpHandshake" => "crate::types::ProtocolFamily::JpHandshake",
"DInput" => "crate::types::ProtocolFamily::DInput",
"DS4Boot" => "crate::types::ProtocolFamily::DS4Boot",
"Unknown" => "crate::types::ProtocolFamily::Unknown",
other => panic!("unknown protocol_family {other}"),
};
out.push_str(&format!(
" crate::registry::PidRegistryRow {{ name: \"{name}\", pid: {pid}, support_level: {support_level}, protocol_family: {protocol_family} }},\n"
));
}
out.push_str("]\n;");
fs::write(out_path, out).expect("failed writing generated_pid_registry.rs");
}
fn generate_command_registry(csv_path: &Path, out_path: &Path) {
let mut rdr = csv::Reader::from_path(csv_path).expect("failed to open command_matrix.csv");
let mut out = String::new();
out.push_str("pub const COMMAND_REGISTRY: &[crate::registry::CommandRegistryRow] = &[\n");
for rec in rdr.records() {
let rec = rec.expect("invalid command csv record");
let id = rec.get(0).expect("command_id");
let safety_class = match rec.get(1).expect("safety_class") {
"SafeRead" => "crate::types::SafetyClass::SafeRead",
"SafeWrite" => "crate::types::SafetyClass::SafeWrite",
"UnsafeBoot" => "crate::types::SafetyClass::UnsafeBoot",
"UnsafeFirmware" => "crate::types::SafetyClass::UnsafeFirmware",
other => panic!("unknown safety_class {other}"),
};
let confidence = match rec.get(2).expect("confidence") {
"confirmed" => "crate::types::CommandConfidence::Confirmed",
"inferred" => "crate::types::CommandConfidence::Inferred",
other => panic!("unknown confidence {other}"),
};
let experimental_default = rec
.get(3)
.expect("experimental_default")
.parse::<bool>()
.expect("invalid experimental_default");
let report_id = parse_u8(rec.get(4).expect("report_id"));
let request_hex = rec.get(6).expect("request_hex");
let request = hex_to_bytes(request_hex);
let expected_response = rec.get(7).expect("expected_response");
out.push_str(&format!(
" crate::registry::CommandRegistryRow {{ id: crate::command::CommandId::{id}, safety_class: {safety_class}, confidence: {confidence}, experimental_default: {experimental_default}, report_id: {report_id}, request: &{request:?}, expected_response: \"{expected_response}\" }},\n"
));
}
out.push_str("]\n;");
fs::write(out_path, out).expect("failed writing generated_command_registry.rs");
}
fn parse_u8(value: &str) -> u8 {
if let Some(stripped) = value.strip_prefix("0x") {
u8::from_str_radix(stripped, 16).expect("invalid hex u8")
} else {
value.parse::<u8>().expect("invalid u8")
}
}
fn hex_to_bytes(hex: &str) -> Vec<u8> {
let hex = hex.trim();
if hex.len() % 2 != 0 {
panic!("hex length must be even: {hex}");
}
let mut bytes = Vec::with_capacity(hex.len() / 2);
let raw = hex.as_bytes();
for i in (0..raw.len()).step_by(2) {
let hi = (raw[i] as char)
.to_digit(16)
.unwrap_or_else(|| panic!("invalid hex: {hex}"));
let lo = (raw[i + 1] as char)
.to_digit(16)
.unwrap_or_else(|| panic!("invalid hex: {hex}"));
bytes.push(((hi << 4) | lo) as u8);
}
bytes
}

View File

@@ -0,0 +1,60 @@
use crate::types::{CommandConfidence, SafetyClass};
use serde::{Deserialize, Serialize};
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum CommandId {
GetPid,
GetReportRevision,
GetMode,
GetModeAlt,
GetControllerVersion,
GetSuperButton,
SetModeDInput,
Idle,
Version,
ReadProfile,
WriteProfile,
EnterBootloaderA,
EnterBootloaderB,
EnterBootloaderC,
ExitBootloader,
FirmwareChunk,
FirmwareCommit,
}
impl CommandId {
pub const ALL: [CommandId; 17] = [
CommandId::GetPid,
CommandId::GetReportRevision,
CommandId::GetMode,
CommandId::GetModeAlt,
CommandId::GetControllerVersion,
CommandId::GetSuperButton,
CommandId::SetModeDInput,
CommandId::Idle,
CommandId::Version,
CommandId::ReadProfile,
CommandId::WriteProfile,
CommandId::EnterBootloaderA,
CommandId::EnterBootloaderB,
CommandId::EnterBootloaderC,
CommandId::ExitBootloader,
CommandId::FirmwareChunk,
CommandId::FirmwareCommit,
];
pub fn all() -> &'static [CommandId] {
&Self::ALL
}
}
#[derive(Clone, Debug)]
pub struct CommandDefinition {
pub id: CommandId,
pub safety_class: SafetyClass,
pub confidence: CommandConfidence,
pub experimental_default: bool,
pub report_id: u8,
pub request: &'static [u8],
pub expected_response: &'static str,
}

View File

@@ -0,0 +1,65 @@
use crate::command::CommandId;
use crate::types::VidPid;
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum BitdoErrorCode {
Transport,
Timeout,
InvalidResponse,
MalformedResponse,
UnsupportedForPid,
ExperimentalRequired,
UnsafeCommandDenied,
UnknownPid,
InvalidInput,
UnknownCommand,
DeviceNotOpen,
}
#[derive(Debug, Error)]
pub enum BitdoError {
#[error("transport error: {0}")]
Transport(String),
#[error("timeout while waiting for device response")]
Timeout,
#[error("invalid response for {command:?}: {reason}")]
InvalidResponse { command: CommandId, reason: String },
#[error("malformed response for {command:?}: len={len}")]
MalformedResponse { command: CommandId, len: usize },
#[error("unsupported command {command:?} for PID {pid:#06x}")]
UnsupportedForPid { command: CommandId, pid: u16 },
#[error("inferred command {command:?} requires --experimental")]
ExperimentalRequired { command: CommandId },
#[error("unsafe command {command:?} requires --unsafe and --i-understand-brick-risk")]
UnsafeCommandDenied { command: CommandId },
#[error("unknown PID {0:#06x}")]
UnknownPid(u16),
#[error("invalid input: {0}")]
InvalidInput(String),
#[error("command definition not found: {0:?}")]
UnknownCommand(CommandId),
#[error("device not open for {0}")]
DeviceNotOpen(VidPid),
}
impl BitdoError {
pub fn code(&self) -> BitdoErrorCode {
match self {
BitdoError::Transport(_) => BitdoErrorCode::Transport,
BitdoError::Timeout => BitdoErrorCode::Timeout,
BitdoError::InvalidResponse { .. } => BitdoErrorCode::InvalidResponse,
BitdoError::MalformedResponse { .. } => BitdoErrorCode::MalformedResponse,
BitdoError::UnsupportedForPid { .. } => BitdoErrorCode::UnsupportedForPid,
BitdoError::ExperimentalRequired { .. } => BitdoErrorCode::ExperimentalRequired,
BitdoError::UnsafeCommandDenied { .. } => BitdoErrorCode::UnsafeCommandDenied,
BitdoError::UnknownPid(_) => BitdoErrorCode::UnknownPid,
BitdoError::InvalidInput(_) => BitdoErrorCode::InvalidInput,
BitdoError::UnknownCommand(_) => BitdoErrorCode::UnknownCommand,
BitdoError::DeviceNotOpen(_) => BitdoErrorCode::DeviceNotOpen,
}
}
}
pub type Result<T> = std::result::Result<T, BitdoError>;

View File

@@ -0,0 +1,56 @@
use crate::command::CommandId;
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Report64(pub [u8; 64]);
impl Report64 {
pub fn as_slice(&self) -> &[u8] {
&self.0
}
}
impl TryFrom<&[u8]> for Report64 {
type Error = String;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
if value.len() != 64 {
return Err(format!("expected 64 bytes, got {}", value.len()));
}
let mut arr = [0u8; 64];
arr.copy_from_slice(value);
Ok(Self(arr))
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VariableReport(pub Vec<u8>);
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CommandFrame {
pub id: CommandId,
pub payload: Vec<u8>,
pub report_id: u8,
pub expected_response: &'static str,
}
impl CommandFrame {
pub fn encode(&self) -> Vec<u8> {
self.payload.clone()
}
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum ResponseStatus {
Ok,
Invalid,
Malformed,
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ResponseFrame {
pub raw: Vec<u8>,
pub status: ResponseStatus,
pub parsed_fields: BTreeMap<String, u32>,
}

View File

@@ -0,0 +1,125 @@
#![cfg(feature = "hidapi-backend")]
use crate::error::{BitdoError, Result};
use crate::transport::Transport;
use crate::types::VidPid;
use hidapi::{HidApi, HidDevice};
#[derive(Clone, Debug)]
pub struct EnumeratedDevice {
pub vid_pid: VidPid,
pub product: Option<String>,
pub manufacturer: Option<String>,
pub serial: Option<String>,
pub path: String,
}
pub fn enumerate_hid_devices() -> Result<Vec<EnumeratedDevice>> {
let api = HidApi::new().map_err(|e| BitdoError::Transport(e.to_string()))?;
let mut devices = Vec::new();
for dev in api.device_list() {
devices.push(EnumeratedDevice {
vid_pid: VidPid::new(dev.vendor_id(), dev.product_id()),
product: dev.product_string().map(ToOwned::to_owned),
manufacturer: dev.manufacturer_string().map(ToOwned::to_owned),
serial: dev.serial_number().map(ToOwned::to_owned),
path: dev.path().to_string_lossy().to_string(),
});
}
Ok(devices)
}
pub struct HidTransport {
api: Option<HidApi>,
device: Option<HidDevice>,
target: Option<VidPid>,
}
impl HidTransport {
pub fn new() -> Self {
Self {
api: None,
device: None,
target: None,
}
}
}
impl Default for HidTransport {
fn default() -> Self {
Self::new()
}
}
impl Transport for HidTransport {
fn open(&mut self, vid_pid: VidPid) -> Result<()> {
let api = HidApi::new().map_err(|e| BitdoError::Transport(e.to_string()))?;
let device = api
.open(vid_pid.vid, vid_pid.pid)
.map_err(|e| BitdoError::Transport(format!("open failed for {}: {}", vid_pid, e)))?;
self.target = Some(vid_pid);
self.device = Some(device);
self.api = Some(api);
Ok(())
}
fn close(&mut self) -> Result<()> {
self.device = None;
self.api = None;
self.target = None;
Ok(())
}
fn write(&mut self, data: &[u8]) -> Result<usize> {
let device = self
.device
.as_ref()
.ok_or_else(|| BitdoError::Transport("HID transport not open".to_owned()))?;
device
.write(data)
.map_err(|e| BitdoError::Transport(e.to_string()))
}
fn read(&mut self, len: usize, timeout_ms: u64) -> Result<Vec<u8>> {
let device = self
.device
.as_ref()
.ok_or_else(|| BitdoError::Transport("HID transport not open".to_owned()))?;
let mut buf = vec![0u8; len];
let read = device
.read_timeout(&mut buf, timeout_ms as i32)
.map_err(|e| BitdoError::Transport(e.to_string()))?;
if read == 0 {
return Err(BitdoError::Timeout);
}
buf.truncate(read);
Ok(buf)
}
fn write_feature(&mut self, data: &[u8]) -> Result<usize> {
let device = self
.device
.as_ref()
.ok_or_else(|| BitdoError::Transport("HID transport not open".to_owned()))?;
device
.send_feature_report(data)
.map_err(|e| BitdoError::Transport(e.to_string()))?;
Ok(data.len())
}
fn read_feature(&mut self, len: usize) -> Result<Vec<u8>> {
let device = self
.device
.as_ref()
.ok_or_else(|| BitdoError::Transport("HID transport not open".to_owned()))?;
let mut buf = vec![0u8; len];
let read = device
.get_feature_report(&mut buf)
.map_err(|e| BitdoError::Transport(e.to_string()))?;
if read == 0 {
return Err(BitdoError::Timeout);
}
buf.truncate(read);
Ok(buf)
}
}

View File

@@ -0,0 +1,30 @@
mod command;
mod error;
mod frame;
#[cfg(feature = "hidapi-backend")]
mod hid_transport;
mod profile;
mod registry;
mod session;
mod transport;
mod types;
pub use command::{CommandDefinition, CommandId};
pub use error::{BitdoError, BitdoErrorCode, Result};
pub use frame::{CommandFrame, Report64, ResponseFrame, ResponseStatus, VariableReport};
#[cfg(feature = "hidapi-backend")]
pub use hid_transport::{enumerate_hid_devices, EnumeratedDevice, HidTransport};
pub use profile::ProfileBlob;
pub use registry::{
command_registry, device_profile_for, find_command, find_pid, pid_registry, CommandRegistryRow,
PidRegistryRow,
};
pub use session::{
validate_response, CommandExecutionReport, DeviceSession, DiagCommandStatus, DiagProbeResult,
FirmwareTransferReport, IdentifyResult, ModeState, RetryPolicy, SessionConfig, TimeoutProfile,
};
pub use transport::{MockTransport, Transport};
pub use types::{
CommandConfidence, DeviceProfile, PidCapability, ProtocolFamily, SafetyClass, SupportEvidence,
SupportLevel, VidPid,
};

View File

@@ -0,0 +1,63 @@
use crate::error::{BitdoError, Result};
use serde::{Deserialize, Serialize};
const MAGIC: &[u8; 4] = b"BDP1";
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ProfileBlob {
pub slot: u8,
pub payload: Vec<u8>,
}
impl ProfileBlob {
pub fn to_bytes(&self) -> Vec<u8> {
let mut out = Vec::with_capacity(4 + 1 + 2 + self.payload.len() + 4);
out.extend_from_slice(MAGIC);
out.push(self.slot);
out.extend_from_slice(&(self.payload.len() as u16).to_le_bytes());
out.extend_from_slice(&self.payload);
let checksum = checksum(&out[4..]);
out.extend_from_slice(&checksum.to_le_bytes());
out
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < 11 {
return Err(BitdoError::InvalidInput(
"profile blob too short".to_owned(),
));
}
if &data[0..4] != MAGIC {
return Err(BitdoError::InvalidInput("invalid profile magic".to_owned()));
}
let slot = data[4];
let len = u16::from_le_bytes([data[5], data[6]]) as usize;
let payload_end = 7 + len;
if payload_end + 4 > data.len() {
return Err(BitdoError::InvalidInput(
"profile length exceeds blob size".to_owned(),
));
}
let payload = data[7..payload_end].to_vec();
let expected = u32::from_le_bytes([
data[payload_end],
data[payload_end + 1],
data[payload_end + 2],
data[payload_end + 3],
]);
let actual = checksum(&data[4..payload_end]);
if expected != actual {
return Err(BitdoError::InvalidInput(format!(
"checksum mismatch expected={expected:#x} actual={actual:#x}"
)));
}
Ok(Self { slot, payload })
}
}
fn checksum(data: &[u8]) -> u32 {
data.iter().fold(0u32, |acc, b| acc.wrapping_add(*b as u32))
}

View File

@@ -0,0 +1,82 @@
use crate::command::CommandId;
use crate::types::{
CommandConfidence, DeviceProfile, PidCapability, ProtocolFamily, SafetyClass, SupportEvidence,
SupportLevel, VidPid,
};
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PidRegistryRow {
pub name: &'static str,
pub pid: u16,
pub support_level: SupportLevel,
pub protocol_family: ProtocolFamily,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CommandRegistryRow {
pub id: CommandId,
pub safety_class: SafetyClass,
pub confidence: CommandConfidence,
pub experimental_default: bool,
pub report_id: u8,
pub request: &'static [u8],
pub expected_response: &'static str,
}
include!(concat!(env!("OUT_DIR"), "/generated_pid_registry.rs"));
include!(concat!(env!("OUT_DIR"), "/generated_command_registry.rs"));
pub fn pid_registry() -> &'static [PidRegistryRow] {
PID_REGISTRY
}
pub fn command_registry() -> &'static [CommandRegistryRow] {
COMMAND_REGISTRY
}
pub fn find_pid(pid: u16) -> Option<&'static PidRegistryRow> {
PID_REGISTRY.iter().find(|row| row.pid == pid)
}
pub fn find_command(id: CommandId) -> Option<&'static CommandRegistryRow> {
COMMAND_REGISTRY.iter().find(|row| row.id == id)
}
pub fn default_capability_for(
support_level: SupportLevel,
_protocol_family: ProtocolFamily,
) -> PidCapability {
match support_level {
SupportLevel::Full => PidCapability::full(),
SupportLevel::DetectOnly => PidCapability::identify_only(),
}
}
pub fn default_evidence_for(support_level: SupportLevel) -> SupportEvidence {
match support_level {
SupportLevel::Full => SupportEvidence::Confirmed,
SupportLevel::DetectOnly => SupportEvidence::Inferred,
}
}
pub fn device_profile_for(vid_pid: VidPid) -> DeviceProfile {
if let Some(row) = find_pid(vid_pid.pid) {
DeviceProfile {
vid_pid,
name: row.name.to_owned(),
support_level: row.support_level,
protocol_family: row.protocol_family,
capability: default_capability_for(row.support_level, row.protocol_family),
evidence: default_evidence_for(row.support_level),
}
} else {
DeviceProfile {
vid_pid,
name: "PID_UNKNOWN".to_owned(),
support_level: SupportLevel::DetectOnly,
protocol_family: ProtocolFamily::Unknown,
capability: PidCapability::identify_only(),
evidence: SupportEvidence::Untested,
}
}
}

View File

@@ -0,0 +1,715 @@
use crate::command::CommandId;
use crate::error::{BitdoError, BitdoErrorCode, Result};
use crate::frame::{CommandFrame, ResponseFrame, ResponseStatus};
use crate::profile::ProfileBlob;
use crate::registry::{device_profile_for, find_command, find_pid, CommandRegistryRow};
use crate::transport::Transport;
use crate::types::{
CommandConfidence, DeviceProfile, PidCapability, ProtocolFamily, SafetyClass, SupportEvidence,
SupportLevel, VidPid,
};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::thread;
use std::time::Duration;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RetryPolicy {
pub max_attempts: u8,
pub backoff_ms: u64,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
backoff_ms: 10,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TimeoutProfile {
pub probe_ms: u64,
pub io_ms: u64,
pub firmware_ms: u64,
}
impl Default for TimeoutProfile {
fn default() -> Self {
Self {
probe_ms: 200,
io_ms: 400,
firmware_ms: 1_200,
}
}
}
#[derive(Clone, Debug)]
pub struct SessionConfig {
pub retry_policy: RetryPolicy,
pub timeout_profile: TimeoutProfile,
pub allow_unsafe: bool,
pub brick_risk_ack: bool,
pub experimental: bool,
pub trace_enabled: bool,
}
impl Default for SessionConfig {
fn default() -> Self {
Self {
retry_policy: RetryPolicy::default(),
timeout_profile: TimeoutProfile::default(),
allow_unsafe: false,
brick_risk_ack: false,
experimental: false,
trace_enabled: true,
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CommandExecutionReport {
pub command: CommandId,
pub attempts: u8,
pub validator: String,
pub status: ResponseStatus,
pub bytes_written: usize,
pub bytes_read: usize,
pub error_code: Option<BitdoErrorCode>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiagCommandStatus {
pub command: CommandId,
pub ok: bool,
pub error_code: Option<BitdoErrorCode>,
pub detail: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DiagProbeResult {
pub target: VidPid,
pub profile_name: String,
pub support_level: SupportLevel,
pub protocol_family: ProtocolFamily,
pub capability: PidCapability,
pub evidence: SupportEvidence,
pub transport_ready: bool,
pub command_checks: Vec<DiagCommandStatus>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct IdentifyResult {
pub target: VidPid,
pub profile_name: String,
pub support_level: SupportLevel,
pub protocol_family: ProtocolFamily,
pub capability: PidCapability,
pub evidence: SupportEvidence,
pub detected_pid: Option<u16>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModeState {
pub mode: u8,
pub source: String,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FirmwareTransferReport {
pub bytes_total: usize,
pub chunk_size: usize,
pub chunks_sent: usize,
pub dry_run: bool,
}
pub struct DeviceSession<T: Transport> {
transport: T,
target: VidPid,
profile: DeviceProfile,
config: SessionConfig,
trace: Vec<CommandExecutionReport>,
last_execution: Option<CommandExecutionReport>,
}
impl<T: Transport> DeviceSession<T> {
pub fn new(mut transport: T, target: VidPid, config: SessionConfig) -> Result<Self> {
transport.open(target)?;
let profile = device_profile_for(target);
Ok(Self {
transport,
target,
profile,
config,
trace: Vec::new(),
last_execution: None,
})
}
pub fn profile(&self) -> &DeviceProfile {
&self.profile
}
pub fn trace(&self) -> &[CommandExecutionReport] {
&self.trace
}
pub fn last_execution_report(&self) -> Option<&CommandExecutionReport> {
self.last_execution.as_ref()
}
pub fn close(&mut self) -> Result<()> {
self.transport.close()
}
pub fn into_transport(self) -> T {
self.transport
}
pub fn identify(&mut self) -> Result<IdentifyResult> {
let detected_pid = match self.send_command(CommandId::GetPid, None) {
Ok(resp) => resp
.parsed_fields
.get("detected_pid")
.copied()
.map(|v| v as u16),
Err(_) => None,
};
let profile_row = detected_pid.and_then(find_pid);
let mut profile = self.profile.clone();
if let Some(row) = profile_row {
profile = device_profile_for(VidPid::new(self.target.vid, row.pid));
}
Ok(IdentifyResult {
target: self.target,
profile_name: profile.name,
support_level: profile.support_level,
protocol_family: profile.protocol_family,
capability: profile.capability,
evidence: profile.evidence,
detected_pid,
})
}
pub fn diag_probe(&mut self) -> DiagProbeResult {
let checks = [
CommandId::GetPid,
CommandId::GetReportRevision,
CommandId::GetMode,
CommandId::GetControllerVersion,
]
.iter()
.map(|cmd| match self.send_command(*cmd, None) {
Ok(_) => DiagCommandStatus {
command: *cmd,
ok: true,
error_code: None,
detail: "ok".to_owned(),
},
Err(err) => DiagCommandStatus {
command: *cmd,
ok: false,
error_code: Some(err.code()),
detail: err.to_string(),
},
})
.collect::<Vec<_>>();
DiagProbeResult {
target: self.target,
profile_name: self.profile.name.clone(),
support_level: self.profile.support_level,
protocol_family: self.profile.protocol_family,
capability: self.profile.capability,
evidence: self.profile.evidence,
transport_ready: true,
command_checks: checks,
}
}
pub fn get_mode(&mut self) -> Result<ModeState> {
let resp = self.send_command(CommandId::GetMode, None)?;
if let Some(mode) = resp.parsed_fields.get("mode").copied() {
return Ok(ModeState {
mode: mode as u8,
source: "GetMode".to_owned(),
});
}
let resp = self.send_command(CommandId::GetModeAlt, None)?;
let mode = resp.parsed_fields.get("mode").copied().unwrap_or_default() as u8;
Ok(ModeState {
mode,
source: "GetModeAlt".to_owned(),
})
}
pub fn set_mode(&mut self, mode: u8) -> Result<ModeState> {
let row = self.ensure_command_allowed(CommandId::SetModeDInput)?;
let mut payload = row.request.to_vec();
if payload.len() < 5 {
return Err(BitdoError::InvalidInput(
"SetModeDInput payload shorter than expected".to_owned(),
));
}
payload[4] = mode;
self.send_row(row, Some(&payload))?;
self.get_mode()
}
pub fn read_profile(&mut self, slot: u8) -> Result<ProfileBlob> {
let row = self.ensure_command_allowed(CommandId::ReadProfile)?;
let mut payload = row.request.to_vec();
if payload.len() > 3 {
payload[3] = slot;
}
let resp = self.send_row(row, Some(&payload))?;
Ok(ProfileBlob {
slot,
payload: resp.raw,
})
}
pub fn write_profile(&mut self, slot: u8, profile: &ProfileBlob) -> Result<()> {
let row = self.ensure_command_allowed(CommandId::WriteProfile)?;
let mut payload = row.request.to_vec();
if payload.len() > 3 {
payload[3] = slot;
}
let serialized = profile.to_bytes();
let copy_len = (payload.len().saturating_sub(8)).min(serialized.len());
if copy_len > 0 {
payload[8..8 + copy_len].copy_from_slice(&serialized[..copy_len]);
}
self.send_row(row, Some(&payload))?;
Ok(())
}
pub fn enter_bootloader(&mut self) -> Result<()> {
self.send_command(CommandId::EnterBootloaderA, None)?;
self.send_command(CommandId::EnterBootloaderB, None)?;
self.send_command(CommandId::EnterBootloaderC, None)?;
Ok(())
}
pub fn firmware_transfer(
&mut self,
image: &[u8],
chunk_size: usize,
dry_run: bool,
) -> Result<FirmwareTransferReport> {
if chunk_size == 0 {
return Err(BitdoError::InvalidInput(
"chunk size must be greater than zero".to_owned(),
));
}
let chunk_count = image.len().div_ceil(chunk_size);
if dry_run {
return Ok(FirmwareTransferReport {
bytes_total: image.len(),
chunk_size,
chunks_sent: chunk_count,
dry_run,
});
}
let row = self.ensure_command_allowed(CommandId::FirmwareChunk)?;
for chunk in image.chunks(chunk_size) {
let mut payload = row.request.to_vec();
let offset = 4;
let copy_len = chunk.len().min(payload.len().saturating_sub(offset));
if copy_len > 0 {
payload[offset..offset + copy_len].copy_from_slice(&chunk[..copy_len]);
}
self.send_row(row, Some(&payload))?;
}
self.send_command(CommandId::FirmwareCommit, None)?;
Ok(FirmwareTransferReport {
bytes_total: image.len(),
chunk_size,
chunks_sent: chunk_count,
dry_run,
})
}
pub fn exit_bootloader(&mut self) -> Result<()> {
self.send_command(CommandId::ExitBootloader, None)?;
Ok(())
}
pub fn send_command(
&mut self,
command: CommandId,
override_payload: Option<&[u8]>,
) -> Result<ResponseFrame> {
let row = self.ensure_command_allowed(command)?;
self.send_row(row, override_payload)
}
fn send_row(
&mut self,
row: &CommandRegistryRow,
override_payload: Option<&[u8]>,
) -> Result<ResponseFrame> {
let payload = override_payload.unwrap_or(row.request).to_vec();
let frame = CommandFrame {
id: row.id,
payload,
report_id: row.report_id,
expected_response: row.expected_response,
};
let encoded = frame.encode();
let bytes_written = self.transport.write(&encoded)?;
if row.expected_response == "none" {
let report = CommandExecutionReport {
command: row.id,
attempts: 1,
validator: self.validator_name(row),
status: ResponseStatus::Ok,
bytes_written,
bytes_read: 0,
error_code: None,
};
self.record_execution(report);
return Ok(ResponseFrame {
raw: Vec::new(),
status: ResponseStatus::Ok,
parsed_fields: BTreeMap::new(),
});
}
let timeout_ms = self.timeout_for_command(row);
let expected_min_len = minimum_response_len(row.id);
let attempts_total = self.config.retry_policy.max_attempts.max(1);
let mut last_status = ResponseStatus::Malformed;
let mut last_len = 0usize;
for attempt in 1..=attempts_total {
match self.read_response_reassembled(timeout_ms, expected_min_len) {
Ok(raw) => {
let status = validate_response(row.id, &raw);
if status == ResponseStatus::Ok {
let report = CommandExecutionReport {
command: row.id,
attempts: attempt,
validator: self.validator_name(row),
status: ResponseStatus::Ok,
bytes_written,
bytes_read: raw.len(),
error_code: None,
};
self.record_execution(report);
return Ok(ResponseFrame {
parsed_fields: parse_fields(row.id, &raw),
raw,
status,
});
}
last_status = status;
last_len = raw.len();
}
Err(BitdoError::Timeout) => {
last_status = ResponseStatus::Malformed;
last_len = 0;
}
Err(err) => {
let report = CommandExecutionReport {
command: row.id,
attempts: attempt,
validator: self.validator_name(row),
status: ResponseStatus::Malformed,
bytes_written,
bytes_read: 0,
error_code: Some(err.code()),
};
self.record_execution(report);
return Err(err);
}
}
if attempt < attempts_total && self.config.retry_policy.backoff_ms > 0 {
thread::sleep(Duration::from_millis(self.config.retry_policy.backoff_ms));
}
}
match last_status {
ResponseStatus::Invalid => {
let err = BitdoError::InvalidResponse {
command: row.id,
reason: "response signature mismatch".to_owned(),
};
let report = CommandExecutionReport {
command: row.id,
attempts: attempts_total,
validator: self.validator_name(row),
status: ResponseStatus::Invalid,
bytes_written,
bytes_read: last_len,
error_code: Some(err.code()),
};
self.record_execution(report);
Err(err)
}
_ => {
let err = BitdoError::MalformedResponse {
command: row.id,
len: last_len,
};
let report = CommandExecutionReport {
command: row.id,
attempts: attempts_total,
validator: self.validator_name(row),
status: ResponseStatus::Malformed,
bytes_written,
bytes_read: last_len,
error_code: Some(err.code()),
};
self.record_execution(report);
Err(err)
}
}
}
fn read_response_reassembled(
&mut self,
timeout_ms: u64,
expected_min_len: usize,
) -> Result<Vec<u8>> {
let mut raw = Vec::new();
// Some devices can split replies across multiple reads; reassemble bounded chunks.
for _ in 0..3 {
let chunk = self.transport.read(64, timeout_ms)?;
if chunk.is_empty() {
continue;
}
raw.extend_from_slice(&chunk);
if raw.len() >= expected_min_len {
break;
}
}
if raw.is_empty() {
return Err(BitdoError::Timeout);
}
Ok(raw)
}
fn record_execution(&mut self, report: CommandExecutionReport) {
self.last_execution = Some(report.clone());
if self.config.trace_enabled {
self.trace.push(report);
}
}
fn timeout_for_command(&self, row: &CommandRegistryRow) -> u64 {
match row.safety_class {
SafetyClass::UnsafeFirmware => self.config.timeout_profile.firmware_ms,
SafetyClass::SafeRead => self.config.timeout_profile.probe_ms,
SafetyClass::SafeWrite | SafetyClass::UnsafeBoot => self.config.timeout_profile.io_ms,
}
}
fn validator_name(&self, row: &CommandRegistryRow) -> String {
format!(
"pid={:#06x};signature={}",
self.target.pid, row.expected_response
)
}
fn ensure_command_allowed(&self, command: CommandId) -> Result<&'static CommandRegistryRow> {
let row = find_command(command).ok_or(BitdoError::UnknownCommand(command))?;
if row.confidence == CommandConfidence::Inferred && !self.config.experimental {
return Err(BitdoError::ExperimentalRequired { command });
}
if !is_command_allowed_by_family(self.profile.protocol_family, command)
|| !is_command_allowed_by_capability(self.profile.capability, command)
{
return Err(BitdoError::UnsupportedForPid {
command,
pid: self.target.pid,
});
}
if row.safety_class.is_unsafe() {
if self.profile.support_level != SupportLevel::Full {
return Err(BitdoError::UnsupportedForPid {
command,
pid: self.target.pid,
});
}
if !(self.config.allow_unsafe && self.config.brick_risk_ack) {
return Err(BitdoError::UnsafeCommandDenied { command });
}
}
if row.safety_class == SafetyClass::SafeWrite
&& self.profile.support_level == SupportLevel::DetectOnly
{
return Err(BitdoError::UnsupportedForPid {
command,
pid: self.target.pid,
});
}
Ok(row)
}
}
fn is_command_allowed_by_capability(cap: PidCapability, command: CommandId) -> bool {
match command {
CommandId::GetPid
| CommandId::GetReportRevision
| CommandId::GetControllerVersion
| CommandId::Version
| CommandId::Idle
| CommandId::GetSuperButton => true,
CommandId::GetMode | CommandId::GetModeAlt | CommandId::SetModeDInput => cap.supports_mode,
CommandId::ReadProfile | CommandId::WriteProfile => cap.supports_profile_rw,
CommandId::EnterBootloaderA
| CommandId::EnterBootloaderB
| CommandId::EnterBootloaderC
| CommandId::ExitBootloader => cap.supports_boot,
CommandId::FirmwareChunk | CommandId::FirmwareCommit => cap.supports_firmware,
}
}
fn is_command_allowed_by_family(family: ProtocolFamily, command: CommandId) -> bool {
match family {
ProtocolFamily::Unknown => matches!(
command,
CommandId::GetPid
| CommandId::GetReportRevision
| CommandId::GetControllerVersion
| CommandId::Version
| CommandId::Idle
),
ProtocolFamily::JpHandshake => !matches!(
command,
CommandId::SetModeDInput
| CommandId::ReadProfile
| CommandId::WriteProfile
| CommandId::FirmwareChunk
| CommandId::FirmwareCommit
),
ProtocolFamily::DS4Boot => matches!(
command,
CommandId::EnterBootloaderA
| CommandId::EnterBootloaderB
| CommandId::EnterBootloaderC
| CommandId::ExitBootloader
| CommandId::FirmwareChunk
| CommandId::FirmwareCommit
| CommandId::GetPid
),
ProtocolFamily::Standard64 | ProtocolFamily::DInput => true,
}
}
pub fn validate_response(command: CommandId, response: &[u8]) -> ResponseStatus {
if response.len() < 2 {
return ResponseStatus::Malformed;
}
match command {
CommandId::GetPid => {
if response.len() < 24 {
return ResponseStatus::Malformed;
}
if response[0] == 0x02 && response[1] == 0x05 && response[4] == 0xC1 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
CommandId::GetReportRevision => {
if response.len() < 6 {
return ResponseStatus::Malformed;
}
if response[0] == 0x02 && response[1] == 0x04 && response[5] == 0x01 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
CommandId::GetMode | CommandId::GetModeAlt => {
if response.len() < 6 {
return ResponseStatus::Malformed;
}
if response[0] == 0x02 && response[1] == 0x05 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
CommandId::GetControllerVersion | CommandId::Version => {
if response.len() < 5 {
return ResponseStatus::Malformed;
}
if response[0] == 0x02 && response[1] == 0x22 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
CommandId::Idle => {
if response[0] == 0x02 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
CommandId::EnterBootloaderA
| CommandId::EnterBootloaderB
| CommandId::EnterBootloaderC
| CommandId::ExitBootloader => ResponseStatus::Ok,
_ => {
if response[0] == 0x02 {
ResponseStatus::Ok
} else {
ResponseStatus::Invalid
}
}
}
}
fn minimum_response_len(command: CommandId) -> usize {
match command {
CommandId::GetPid => 24,
CommandId::GetReportRevision => 6,
CommandId::GetMode | CommandId::GetModeAlt => 6,
CommandId::GetControllerVersion | CommandId::Version => 5,
_ => 2,
}
}
fn parse_fields(command: CommandId, response: &[u8]) -> BTreeMap<String, u32> {
let mut parsed = BTreeMap::new();
match command {
CommandId::GetPid if response.len() >= 24 => {
let pid = u16::from_le_bytes([response[22], response[23]]);
parsed.insert("detected_pid".to_owned(), pid as u32);
}
CommandId::GetMode | CommandId::GetModeAlt if response.len() >= 6 => {
parsed.insert("mode".to_owned(), response[5] as u32);
}
CommandId::GetControllerVersion | CommandId::Version if response.len() >= 5 => {
let fw = u16::from_le_bytes([response[2], response[3]]) as u32;
parsed.insert("version_x100".to_owned(), fw);
parsed.insert("beta".to_owned(), response[4] as u32);
}
_ => {}
}
parsed
}

View File

@@ -0,0 +1,126 @@
use crate::error::{BitdoError, Result};
use crate::types::VidPid;
use std::collections::VecDeque;
pub trait Transport {
fn open(&mut self, vid_pid: VidPid) -> Result<()>;
fn close(&mut self) -> Result<()>;
fn write(&mut self, data: &[u8]) -> Result<usize>;
fn read(&mut self, len: usize, timeout_ms: u64) -> Result<Vec<u8>>;
fn write_feature(&mut self, data: &[u8]) -> Result<usize>;
fn read_feature(&mut self, len: usize) -> Result<Vec<u8>>;
}
impl<T: Transport + ?Sized> Transport for Box<T> {
fn open(&mut self, vid_pid: VidPid) -> Result<()> {
(**self).open(vid_pid)
}
fn close(&mut self) -> Result<()> {
(**self).close()
}
fn write(&mut self, data: &[u8]) -> Result<usize> {
(**self).write(data)
}
fn read(&mut self, len: usize, timeout_ms: u64) -> Result<Vec<u8>> {
(**self).read(len, timeout_ms)
}
fn write_feature(&mut self, data: &[u8]) -> Result<usize> {
(**self).write_feature(data)
}
fn read_feature(&mut self, len: usize) -> Result<Vec<u8>> {
(**self).read_feature(len)
}
}
#[derive(Clone, Debug)]
pub enum MockReadEvent {
Data(Vec<u8>),
Timeout,
Error(String),
}
#[derive(Clone, Debug, Default)]
pub struct MockTransport {
opened: Option<VidPid>,
reads: VecDeque<MockReadEvent>,
feature_reads: VecDeque<MockReadEvent>,
writes: Vec<Vec<u8>>,
feature_writes: Vec<Vec<u8>>,
}
impl MockTransport {
pub fn push_read_data(&mut self, data: Vec<u8>) {
self.reads.push_back(MockReadEvent::Data(data));
}
pub fn push_read_timeout(&mut self) {
self.reads.push_back(MockReadEvent::Timeout);
}
pub fn push_read_error(&mut self, message: impl Into<String>) {
self.reads.push_back(MockReadEvent::Error(message.into()));
}
pub fn push_feature_read_data(&mut self, data: Vec<u8>) {
self.feature_reads.push_back(MockReadEvent::Data(data));
}
pub fn writes(&self) -> &[Vec<u8>] {
&self.writes
}
pub fn feature_writes(&self) -> &[Vec<u8>] {
&self.feature_writes
}
}
impl Transport for MockTransport {
fn open(&mut self, vid_pid: VidPid) -> Result<()> {
self.opened = Some(vid_pid);
Ok(())
}
fn close(&mut self) -> Result<()> {
self.opened = None;
Ok(())
}
fn write(&mut self, data: &[u8]) -> Result<usize> {
if self.opened.is_none() {
return Err(BitdoError::Transport("mock transport not open".to_owned()));
}
self.writes.push(data.to_vec());
Ok(data.len())
}
fn read(&mut self, _len: usize, _timeout_ms: u64) -> Result<Vec<u8>> {
match self.reads.pop_front() {
Some(MockReadEvent::Data(d)) => Ok(d),
Some(MockReadEvent::Timeout) => Err(BitdoError::Timeout),
Some(MockReadEvent::Error(msg)) => Err(BitdoError::Transport(msg)),
None => Err(BitdoError::Timeout),
}
}
fn write_feature(&mut self, data: &[u8]) -> Result<usize> {
if self.opened.is_none() {
return Err(BitdoError::Transport("mock transport not open".to_owned()));
}
self.feature_writes.push(data.to_vec());
Ok(data.len())
}
fn read_feature(&mut self, _len: usize) -> Result<Vec<u8>> {
match self.feature_reads.pop_front() {
Some(MockReadEvent::Data(d)) => Ok(d),
Some(MockReadEvent::Timeout) => Err(BitdoError::Timeout),
Some(MockReadEvent::Error(msg)) => Err(BitdoError::Transport(msg)),
None => Err(BitdoError::Timeout),
}
}
}

View File

@@ -0,0 +1,116 @@
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use std::str::FromStr;
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct VidPid {
pub vid: u16,
pub pid: u16,
}
impl VidPid {
pub const fn new(vid: u16, pid: u16) -> Self {
Self { vid, pid }
}
}
impl Display for VidPid {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:04x}:{:04x}", self.vid, self.pid)
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum ProtocolFamily {
Standard64,
JpHandshake,
DInput,
DS4Boot,
Unknown,
}
impl FromStr for ProtocolFamily {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Standard64" => Ok(Self::Standard64),
"JpHandshake" => Ok(Self::JpHandshake),
"DInput" => Ok(Self::DInput),
"DS4Boot" => Ok(Self::DS4Boot),
"Unknown" => Ok(Self::Unknown),
_ => Err(format!("unsupported protocol family: {s}")),
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SupportLevel {
Full,
DetectOnly,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SafetyClass {
SafeRead,
SafeWrite,
UnsafeBoot,
UnsafeFirmware,
}
impl SafetyClass {
pub fn is_unsafe(self) -> bool {
matches!(self, Self::UnsafeBoot | Self::UnsafeFirmware)
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum CommandConfidence {
Confirmed,
Inferred,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum SupportEvidence {
Confirmed,
Inferred,
Untested,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct PidCapability {
pub supports_mode: bool,
pub supports_profile_rw: bool,
pub supports_boot: bool,
pub supports_firmware: bool,
}
impl PidCapability {
pub const fn full() -> Self {
Self {
supports_mode: true,
supports_profile_rw: true,
supports_boot: true,
supports_firmware: true,
}
}
pub const fn identify_only() -> Self {
Self {
supports_mode: false,
supports_profile_rw: false,
supports_boot: false,
supports_firmware: false,
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct DeviceProfile {
pub vid_pid: VidPid,
pub name: String,
pub support_level: SupportLevel,
pub protocol_family: ProtocolFamily,
pub capability: PidCapability,
pub evidence: SupportEvidence,
}

View File

@@ -0,0 +1,20 @@
[package]
name = "bitdoctl"
version = "0.1.0"
edition = "2021"
license = "MIT"
[dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
serde_json = { workspace = true }
hex = { workspace = true }
bitdo_proto = { path = "../bitdo_proto" }
[dev-dependencies]
assert_cmd = "2.0"
predicates = "3.1"
[[test]]
name = "cli_snapshot"
path = "../../tests/cli_snapshot.rs"

View File

@@ -0,0 +1,518 @@
use anyhow::{anyhow, Result};
use bitdo_proto::{
command_registry, device_profile_for, enumerate_hid_devices, BitdoErrorCode, CommandId,
DeviceSession, FirmwareTransferReport, HidTransport, MockTransport, ProfileBlob, RetryPolicy,
SessionConfig, TimeoutProfile, Transport, VidPid,
};
use clap::{Parser, Subcommand};
use serde_json::json;
use std::fs;
use std::path::PathBuf;
#[derive(Debug, Parser)]
#[command(name = "bitdoctl")]
#[command(about = "OpenBitdo clean-room protocol CLI")]
struct Cli {
#[arg(long)]
vid: Option<String>,
#[arg(long)]
pid: Option<String>,
#[arg(long)]
json: bool,
#[arg(long = "unsafe")]
allow_unsafe: bool,
#[arg(long = "i-understand-brick-risk")]
brick_risk_ack: bool,
#[arg(long)]
experimental: bool,
#[arg(long)]
mock: bool,
#[arg(long, default_value_t = 3)]
max_attempts: u8,
#[arg(long, default_value_t = 10)]
backoff_ms: u64,
#[arg(long, default_value_t = 200)]
probe_timeout_ms: u64,
#[arg(long, default_value_t = 400)]
io_timeout_ms: u64,
#[arg(long, default_value_t = 1200)]
firmware_timeout_ms: u64,
#[command(subcommand)]
command: Commands,
}
#[derive(Debug, Subcommand)]
enum Commands {
List,
Identify,
Diag {
#[command(subcommand)]
command: DiagCommand,
},
Profile {
#[command(subcommand)]
command: ProfileCommand,
},
Mode {
#[command(subcommand)]
command: ModeCommand,
},
Boot {
#[command(subcommand)]
command: BootCommand,
},
Fw {
#[command(subcommand)]
command: FwCommand,
},
}
#[derive(Debug, Subcommand)]
enum DiagCommand {
Probe,
}
#[derive(Debug, Subcommand)]
enum ProfileCommand {
Dump {
#[arg(long)]
slot: u8,
},
Apply {
#[arg(long)]
slot: u8,
#[arg(long)]
file: PathBuf,
},
}
#[derive(Debug, Subcommand)]
enum ModeCommand {
Get,
Set {
#[arg(long)]
mode: u8,
},
}
#[derive(Debug, Subcommand)]
enum BootCommand {
Enter,
Exit,
}
#[derive(Debug, Subcommand)]
enum FwCommand {
Write {
#[arg(long)]
file: PathBuf,
#[arg(long, default_value_t = 56)]
chunk_size: usize,
#[arg(long)]
dry_run: bool,
},
}
fn main() -> Result<()> {
let cli = Cli::parse();
if let Err(err) = run(cli) {
eprintln!("error: {err}");
return Err(err);
}
Ok(())
}
fn run(cli: Cli) -> Result<()> {
match &cli.command {
Commands::List => handle_list(&cli),
Commands::Identify
| Commands::Diag { .. }
| Commands::Profile { .. }
| Commands::Mode { .. }
| Commands::Boot { .. }
| Commands::Fw { .. } => {
let target = resolve_target(&cli)?;
let transport: Box<dyn Transport> = if cli.mock {
Box::new(mock_transport_for(&cli.command, target)?)
} else {
Box::new(HidTransport::new())
};
let config = SessionConfig {
retry_policy: RetryPolicy {
max_attempts: cli.max_attempts,
backoff_ms: cli.backoff_ms,
},
timeout_profile: TimeoutProfile {
probe_ms: cli.probe_timeout_ms,
io_ms: cli.io_timeout_ms,
firmware_ms: cli.firmware_timeout_ms,
},
allow_unsafe: cli.allow_unsafe,
brick_risk_ack: cli.brick_risk_ack,
experimental: cli.experimental,
trace_enabled: true,
};
let mut session = DeviceSession::new(transport, target, config)?;
match &cli.command {
Commands::Identify => {
let info = session.identify()?;
if cli.json {
println!("{}", serde_json::to_string_pretty(&info)?);
} else {
println!(
"target={} profile={} support={:?} family={:?} evidence={:?} capability={:?} detected_pid={}",
info.target,
info.profile_name,
info.support_level,
info.protocol_family,
info.evidence,
info.capability,
info.detected_pid
.map(|v| format!("{v:#06x}"))
.unwrap_or_else(|| "none".to_owned())
);
}
}
Commands::Diag { command } => match command {
DiagCommand::Probe => {
let diag = session.diag_probe();
if cli.json {
println!("{}", serde_json::to_string_pretty(&diag)?);
} else {
println!(
"diag target={} profile={} family={:?}",
diag.target, diag.profile_name, diag.protocol_family
);
for check in diag.command_checks {
println!(
" {:?}: ok={} code={}",
check.command,
check.ok,
check
.error_code
.map(|c| format!("{c:?}"))
.unwrap_or_else(|| "none".to_owned())
);
}
}
}
},
Commands::Mode { command } => match command {
ModeCommand::Get => {
let mode = session.get_mode()?;
print_mode(mode.mode, &mode.source, cli.json);
}
ModeCommand::Set { mode } => {
let mode_state = session.set_mode(*mode)?;
print_mode(mode_state.mode, &mode_state.source, cli.json);
}
},
Commands::Profile { command } => match command {
ProfileCommand::Dump { slot } => {
let profile = session.read_profile(*slot)?;
if cli.json {
println!(
"{}",
serde_json::to_string_pretty(&json!({
"slot": profile.slot,
"payload_hex": hex::encode(&profile.payload),
}))?
);
} else {
println!(
"slot={} payload_hex={}",
profile.slot,
hex::encode(&profile.payload)
);
}
}
ProfileCommand::Apply { slot, file } => {
let bytes = fs::read(file)?;
let parsed = ProfileBlob::from_bytes(&bytes)?;
let blob = ProfileBlob {
slot: *slot,
payload: parsed.payload,
};
session.write_profile(*slot, &blob)?;
if cli.json {
println!(
"{}",
serde_json::to_string_pretty(&json!({
"applied": true,
"slot": slot,
}))?
);
} else {
println!("applied profile to slot={slot}");
}
}
},
Commands::Boot { command } => {
match command {
BootCommand::Enter => session.enter_bootloader()?,
BootCommand::Exit => session.exit_bootloader()?,
}
if cli.json {
println!(
"{}",
serde_json::to_string_pretty(&json!({
"ok": true,
"command": format!("{:?}", command),
}))?
);
} else {
println!("{:?} completed", command);
}
}
Commands::Fw { command } => match command {
FwCommand::Write {
file,
chunk_size,
dry_run,
} => {
let image = fs::read(file)?;
let report = session.firmware_transfer(&image, *chunk_size, *dry_run)?;
print_fw_report(report, cli.json)?;
}
},
Commands::List => unreachable!(),
}
session.close()?;
Ok(())
}
}
}
fn handle_list(cli: &Cli) -> Result<()> {
if cli.mock {
let profile = device_profile_for(VidPid::new(0x2dc8, 0x6009));
if cli.json {
println!(
"{}",
serde_json::to_string_pretty(&vec![json!({
"vid": "0x2dc8",
"pid": "0x6009",
"product": "Mock 8BitDo Device",
"support_level": format!("{:?}", profile.support_level),
"protocol_family": format!("{:?}", profile.protocol_family),
"capability": profile.capability,
"evidence": format!("{:?}", profile.evidence),
})])?
);
} else {
println!("2dc8:6009 Mock 8BitDo Device");
}
return Ok(());
}
let devices = enumerate_hid_devices()?;
let filtered: Vec<_> = devices
.into_iter()
.filter(|d| d.vid_pid.vid == 0x2dc8)
.collect();
if cli.json {
let out: Vec<_> = filtered
.iter()
.map(|d| {
let profile = device_profile_for(d.vid_pid);
json!({
"vid": format!("{:#06x}", d.vid_pid.vid),
"pid": format!("{:#06x}", d.vid_pid.pid),
"product": d.product,
"manufacturer": d.manufacturer,
"serial": d.serial,
"path": d.path,
"support_level": format!("{:?}", profile.support_level),
"protocol_family": format!("{:?}", profile.protocol_family),
"capability": profile.capability,
"evidence": format!("{:?}", profile.evidence),
})
})
.collect();
println!("{}", serde_json::to_string_pretty(&out)?);
} else {
for d in &filtered {
println!(
"{} {}",
d.vid_pid,
d.product.as_deref().unwrap_or("(unknown product)")
);
}
}
Ok(())
}
fn resolve_target(cli: &Cli) -> Result<VidPid> {
let vid = cli
.vid
.as_deref()
.map(parse_u16)
.transpose()?
.unwrap_or(0x2dc8);
let pid_str = cli
.pid
.as_deref()
.ok_or_else(|| anyhow!("--pid is required for this command"))?;
let pid = parse_u16(pid_str)?;
Ok(VidPid::new(vid, pid))
}
fn parse_u16(input: &str) -> Result<u16> {
if let Some(hex) = input
.strip_prefix("0x")
.or_else(|| input.strip_prefix("0X"))
{
return Ok(u16::from_str_radix(hex, 16)?);
}
Ok(input.parse::<u16>()?)
}
fn mock_transport_for(command: &Commands, target: VidPid) -> Result<MockTransport> {
let mut t = MockTransport::default();
match command {
Commands::Identify => {
t.push_read_data(build_pid_response(target.pid));
}
Commands::Diag { command } => match command {
DiagCommand::Probe => {
t.push_read_data(build_pid_response(target.pid));
t.push_read_data(build_rr_response());
t.push_read_data(build_mode_response(2));
t.push_read_data(build_version_response());
}
},
Commands::Mode { command } => match command {
ModeCommand::Get => t.push_read_data(build_mode_response(2)),
ModeCommand::Set { mode } => {
t.push_read_data(build_ack_response());
t.push_read_data(build_mode_response(*mode));
}
},
Commands::Profile { command } => match command {
ProfileCommand::Dump { slot } => {
let mut raw = vec![0x02, 0x06, 0x00, *slot];
raw.extend_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
t.push_read_data(raw);
}
ProfileCommand::Apply { .. } => {
t.push_read_data(build_ack_response());
}
},
Commands::Boot { .. } => {}
Commands::Fw { command } => {
let chunks = match command {
FwCommand::Write {
file,
chunk_size,
dry_run,
} => {
if *dry_run {
0
} else {
let sz = fs::metadata(file).map(|m| m.len() as usize).unwrap_or(0);
sz.div_ceil(*chunk_size) + 1
}
}
};
for _ in 0..chunks {
t.push_read_data(build_ack_response());
}
}
Commands::List => {}
}
if matches!(command, Commands::Profile { .. } | Commands::Fw { .. })
&& !command_registry()
.iter()
.any(|c| c.id == CommandId::ReadProfile)
{
return Err(anyhow!("command registry is empty"));
}
Ok(t)
}
fn build_ack_response() -> Vec<u8> {
vec![0x02, 0x01, 0x00, 0x00]
}
fn build_mode_response(mode: u8) -> Vec<u8> {
let mut out = vec![0u8; 64];
out[0] = 0x02;
out[1] = 0x05;
out[5] = mode;
out
}
fn build_rr_response() -> Vec<u8> {
let mut out = vec![0u8; 64];
out[0] = 0x02;
out[1] = 0x04;
out[5] = 0x01;
out
}
fn build_version_response() -> Vec<u8> {
let mut out = vec![0u8; 64];
out[0] = 0x02;
out[1] = 0x22;
out[2] = 0x2A;
out[3] = 0x00;
out[4] = 0x01;
out
}
fn build_pid_response(pid: u16) -> Vec<u8> {
let mut out = vec![0u8; 64];
out[0] = 0x02;
out[1] = 0x05;
out[4] = 0xC1;
let [lo, hi] = pid.to_le_bytes();
out[22] = lo;
out[23] = hi;
out
}
fn print_mode(mode: u8, source: &str, as_json: bool) {
if as_json {
println!(
"{}",
serde_json::to_string_pretty(&json!({
"mode": mode,
"source": source,
}))
.expect("json serialization")
);
} else {
println!("mode={} source={}", mode, source);
}
}
fn print_fw_report(report: FirmwareTransferReport, as_json: bool) -> Result<()> {
if as_json {
println!("{}", serde_json::to_string_pretty(&report)?);
} else {
println!(
"bytes_total={} chunk_size={} chunks_sent={} dry_run={}",
report.bytes_total, report.chunk_size, report.chunks_sent, report.dry_run
);
}
Ok(())
}
#[allow(dead_code)]
fn print_error_code(code: BitdoErrorCode, as_json: bool) {
if as_json {
println!(
"{}",
serde_json::to_string_pretty(&json!({ "error_code": format!("{:?}", code) }))
.expect("json serialization")
);
} else {
println!("error_code={:?}", code);
}
}