use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::Type;
use sqlx_core::column::Column;
use sqlx_core::describe::Describe;
use crate::database::DatabaseExt;
use crate::query::QueryMacroInput;
use std::fmt::{self, Display, Formatter};
use syn::parse::{Parse, ParseStream};
use syn::Token;
/// A database column resolved to the Rust identifiers and type used when
/// generating the row-mapping code.
pub struct RustColumn {
/// Identifier parsed from the column name; used as the field name of the
/// output struct in the generated struct literal.
pub(super) ident: Ident,
/// Name of the temporary `let` binding in the generated code: `ident`
/// prefixed with `sqlx_query_as_`.
pub(super) var_name: Ident,
/// Rust type the column value is read as.
pub(super) type_: ColumnType,
}
/// The Rust type chosen for a column, as it is spliced into generated code.
pub(super) enum ColumnType {
/// A concrete type, emitted verbatim.
Exact(TokenStream),
/// An inferred type, emitted as `_`.
Wildcard,
/// An inferred optional type, emitted as `::std::option::Option<_>`.
OptWildcard,
}
impl ColumnType {
    /// Returns `true` for [`ColumnType::Wildcard`] and
    /// [`ColumnType::OptWildcard`], i.e. whenever the type is left for the
    /// compiler to infer, and `false` for [`ColumnType::Exact`].
    pub(super) fn is_wildcard(&self) -> bool {
        match self {
            ColumnType::Exact(_) => false,
            ColumnType::Wildcard | ColumnType::OptWildcard => true,
        }
    }
}
impl ToTokens for ColumnType {
    /// Appends the textual form of this column type to `tokens`: the exact
    /// type as stored, `_`, or `::std::option::Option<_>`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let rendered: TokenStream = match self {
            ColumnType::Exact(exact) => exact.clone(),
            ColumnType::Wildcard => quote! { _ },
            ColumnType::OptWildcard => quote! { ::std::option::Option<_> },
        };
        // A `TokenStream` iterates over its token trees, which is exactly
        // what `append_all` consumes.
        tokens.append_all(rendered);
    }
}
/// Helper for naming a column in diagnostics; see its `Display` impl.
struct DisplayColumn<'a> {
/// Zero-based column index (displayed one-based).
idx: usize,
/// Column name as reported by the database.
name: &'a str,
}
/// A column name split into its identifier and any trailing override
/// (`!`, `?` and/or `: Type`).
struct ColumnDecl {
/// The identifier portion of the column name.
ident: Ident,
/// Nullability/type overrides parsed from the rest of the column name.
r#override: ColumnOverride,
}
/// Overrides embedded in a column name after the identifier.
struct ColumnOverride {
/// Whether the column was forced non-null, forced nullable, or left alone.
nullability: ColumnNullabilityOverride,
/// Whether the column type was given explicitly, as `_`, or left alone.
type_: ColumnTypeOverride,
}
/// Nullability override parsed from the marker following a column identifier.
#[derive(PartialEq)]
enum ColumnNullabilityOverride {
/// `!` was given: treat the column as never null.
NonNull,
/// `?` was given: treat the column as nullable.
Nullable,
/// No marker was given: fall back to the described nullability.
None,
}
/// Type override parsed from the `: Type` suffix of a column name.
enum ColumnTypeOverride {
/// An explicit type was given after the colon.
Exact(Type),
/// `_` was given after the colon: let the compiler infer the type.
Wildcard,
/// No colon was present: use the type derived from the column's type info.
None,
}
impl Display for DisplayColumn<'_> {
    /// Formats as `column #N ("name")`, where `N` is the one-based position
    /// and the name is rendered with `Debug` quoting/escaping.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ordinal = self.idx + 1;
        f.write_fmt(format_args!("column #{} ({:?})", ordinal, self.name))
    }
}
/// Resolves every column of `describe` to a [`RustColumn`], in column order.
///
/// # Errors
///
/// Returns the first error produced while resolving a column; later columns
/// are not processed.
pub fn columns_to_rust<DB: DatabaseExt>(describe: &Describe<DB>) -> crate::Result<Vec<RustColumn>> {
    let column_count = describe.columns().len();
    let mut rust_columns = Vec::with_capacity(column_count);
    for index in 0..column_count {
        rust_columns.push(column_to_rust(describe, index)?);
    }
    Ok(rust_columns)
}
fn column_to_rust<DB: DatabaseExt>(describe: &Describe<DB>, i: usize) -> crate::Result<RustColumn> {
let column = &describe.columns()[i];
let decl = ColumnDecl::parse(&column.name())
.map_err(|e| format!("column name {:?} is invalid: {}", column.name(), e))?;
let ColumnOverride { nullability, type_ } = decl.r#override;
let nullable = match nullability {
ColumnNullabilityOverride::NonNull => false,
ColumnNullabilityOverride::Nullable => true,
ColumnNullabilityOverride::None => describe.nullable(i).unwrap_or(true),
};
let type_ = match (type_, nullable) {
(ColumnTypeOverride::Exact(type_), false) => ColumnType::Exact(type_.to_token_stream()),
(ColumnTypeOverride::Exact(type_), true) => {
ColumnType::Exact(quote! { ::std::option::Option<#type_> })
}
(ColumnTypeOverride::Wildcard, false) => ColumnType::Wildcard,
(ColumnTypeOverride::Wildcard, true) => ColumnType::OptWildcard,
(ColumnTypeOverride::None, _) => {
let type_ = get_column_type::<DB>(i, column);
if !nullable {
ColumnType::Exact(type_)
} else {
ColumnType::Exact(quote! { ::std::option::Option<#type_> })
}
}
};
Ok(RustColumn {
var_name: quote::format_ident!("sqlx_query_as_{}", decl.ident),
ident: decl.ident,
type_,
})
}
/// Generates the expression that runs the query and maps each row into
/// `out_ty`.
///
/// The expansion calls `::sqlx::query_with` with the SQL and `bind_args`, then
/// `try_map`s each row: every column is read into a prefixed temporary
/// (`RustColumn::var_name`) and the temporaries are assembled into a struct
/// literal of `out_ty` keyed by `RustColumn::ident`.
///
/// How each column is read depends on `input.checked` and the column type:
///
/// * checked + exact type: `try_get_unchecked` with the explicit type;
/// * checked + wildcard: `try_get`, letting the compiler infer the type;
/// * unchecked: `try_get_unchecked` with an inferred type.
///
/// When the query came from a file (`input.file_path` is set) the SQL is
/// emitted as `include_str!(path)` spanned at `input.src_span`; otherwise the
/// SQL string is embedded directly.
pub fn quote_query_as<DB: DatabaseExt>(
    input: &QueryMacroInput,
    out_ty: &Type,
    bind_args: &Ident,
    columns: &[RustColumn],
) -> TokenStream {
    let instantiations = columns.iter().enumerate().map(|(i, column)| {
        let var_name = &column.var_name;

        match (input.checked, &column.type_) {
            (true, ColumnType::Exact(type_)) => quote! {
                let #var_name = row.try_get_unchecked::<#type_, _>(#i)?;
            },
            (true, ColumnType::Wildcard) => quote! ( let #var_name = row.try_get(#i)?; ),
            (true, ColumnType::OptWildcard) => {
                quote! ( let #var_name = row.try_get::<::std::option::Option<_>, _>(#i)?; )
            }
            (false, _) => quote!( let #var_name = row.try_get_unchecked(#i)?; ),
        }
    });

    let ident = columns.iter().map(|col| &col.ident);
    let var_name = columns.iter().map(|col| &col.var_name);

    let db_path = DB::db_path();
    let row_path = DB::row_path();

    let sql = if let Some(path) = &input.file_path {
        quote::quote_spanned! { input.src_span => include_str!(#path) }
    } else {
        let sql = &input.sql;
        quote! { #sql }
    };

    // The expansion is pasted into the caller's scope, so `Ok` must be fully
    // qualified: a bare `Ok` would resolve to whatever the caller has in scope
    // (for example a same-named item pulled in by a glob import) and fail to
    // compile or pick the wrong constructor.
    quote! {
        ::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| {
            use ::sqlx::Row as _;
            #(#instantiations)*
            ::std::result::Result::Ok(#out_ty { #(#ident: #var_name),* })
        })
    }
}
pub fn quote_query_scalar<DB: DatabaseExt>(
input: &QueryMacroInput,
bind_args: &Ident,
describe: &Describe<DB>,
) -> crate::Result<TokenStream> {
let columns = describe.columns();
if columns.len() != 1 {
return Err(syn::Error::new(
input.src_span,
format!("expected exactly 1 column, got {}", columns.len()),
)
.into());
}
let ty = if let Ok(rust_col) = column_to_rust(describe, 0) {
rust_col.type_.to_token_stream()
} else if input.checked {
let ty = get_column_type::<DB>(0, &columns[0]);
if describe.nullable(0).unwrap_or(true) {
quote! { ::std::option::Option<#ty> }
} else {
ty
}
} else {
quote! { _ }
};
let db = DB::db_path();
let query = &input.sql;
Ok(quote! {
::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args)
})
}
/// Returns the Rust type for `column` as tokens, looked up from its type info.
///
/// When the database integration has no mapping for the type, the result is a
/// `compile_error!` invocation (spanned at the macro call site) instead of a
/// type. The message names the optional feature to enable if one is reported
/// for the type, and otherwise says the type is unsupported. `i` is the
/// zero-based column index, used only for the diagnostic.
///
/// # Panics
///
/// Panics if the type name returned for the column does not tokenize.
fn get_column_type<DB: DatabaseExt>(i: usize, column: &DB::Column) -> TokenStream {
    let type_info = &*column.type_info();

    if let Some(type_name) = <DB as DatabaseExt>::return_type_for_id(&type_info) {
        return type_name.parse().unwrap();
    }

    // No mapping: build a diagnostic describing which column is at fault.
    let col = DisplayColumn {
        idx: i,
        name: &*column.name(),
    };

    let message = match <DB as DatabaseExt>::get_feature_gate(&type_info) {
        Some(feature_gate) => format!(
            "optional feature `{feat}` required for type {ty} of {col}",
            ty = &type_info,
            feat = feature_gate,
            col = col
        ),
        None => format!(
            "unsupported type {ty} of {col}",
            ty = type_info,
            col = col
        ),
    };

    syn::Error::new(Span::call_site(), message).to_compile_error()
}
impl ColumnDecl {
fn parse(col_name: &str) -> crate::Result<Self> {
let (ident, remainder) = if let Some(i) = col_name.find(&[':', '!', '?'][..]) {
let (ident, remainder) = col_name.split_at(i);
(parse_ident(ident)?, remainder)
} else {
(parse_ident(col_name)?, "")
};
Ok(ColumnDecl {
ident,
r#override: if !remainder.is_empty() {
syn::parse_str(remainder)?
} else {
ColumnOverride {
nullability: ColumnNullabilityOverride::None,
type_: ColumnTypeOverride::None,
}
},
})
}
}
impl Parse for ColumnOverride {
fn parse(input: ParseStream) -> syn::Result<Self> {
let lookahead = input.lookahead1();
let nullability = if lookahead.peek(Token![!]) {
input.parse::<Token![!]>()?;
ColumnNullabilityOverride::NonNull
} else if lookahead.peek(Token![?]) {
input.parse::<Token![?]>()?;
ColumnNullabilityOverride::Nullable
} else {
ColumnNullabilityOverride::None
};
let type_ = if input.lookahead1().peek(Token![:]) {
input.parse::<Token![:]>()?;
let ty = Type::parse(input)?;
if let Type::Infer(_) = ty {
ColumnTypeOverride::Wildcard
} else {
ColumnTypeOverride::Exact(ty)
}
} else {
ColumnTypeOverride::None
};
Ok(Self { nullability, type_ })
}
}
/// Parses `name` as a raw identifier (`r#name`).
///
/// The name is pre-validated by hand: it must be non-empty, start with an
/// alphabetic character or `_`, and contain only alphanumeric characters or
/// `_`. Only then is it handed to `syn`.
///
/// # Errors
///
/// Fails with a "not a valid Rust identifier" message when the name does not
/// pass the pre-validation or `syn` rejects the raw identifier.
fn parse_ident(name: &str) -> crate::Result<Ident> {
    let starts_like_ident = name
        .chars()
        .next()
        .map_or(false, |c| c.is_alphabetic() || c == '_');
    let only_ident_chars = name.chars().all(|c| c.is_alphanumeric() || c == '_');

    let parsed = if starts_like_ident && only_ident_chars {
        let mut raw = String::from("r#");
        raw.push_str(name);
        syn::parse_str::<Ident>(&raw).ok()
    } else {
        None
    };

    parsed.ok_or_else(|| format!("{:?} is not a valid Rust identifier", name).into())
}