src/ct_expand.erl

%% -*- erlang-indent-level: 4;indent-tabs-mode: nil -*-
%% --------------------------------------------------
%% This file is provided to you under the Apache License,
%% Version 2.0 (the "License"); you may not use this file
%% except in compliance with the License.  You may obtain
%% a copy of the License at
%%
%%   http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied.  See the License for the
%% specific language governing permissions and limitations
%% under the License.
%% --------------------------------------------------
%% File    : ct_expand.erl
%% @author  : Ulf Wiger <ulf@wiger.net>
%% @end
%% Created : 7 Apr 2010 by Ulf Wiger <ulf@wiger.net>
%%-------------------------------------------------------------------

%% @doc Compile-time expansion utility
%%
%% This module serves as an example of parse_trans-based transforms,
%% but might also be a useful utility in its own right.
%% The transform searches for calls to the pseudo-function
%% `ct_expand:term(Expr)', and then replaces the call site with the
%% result of evaluating `Expr' at compile-time.
%%
%% For example, the line
%%
%% `ct_expand:term(lists:sort([3,5,2,1,4]))'
%%
%% would be expanded at compile-time to `[1,2,3,4,5]'.
%%
%% ct_expand has now been extended to also evaluate calls to local functions.
%% See examples/ct_expand_test.erl for some examples.
%%
%% A debugging facility exists: passing the option `{ct_expand_trace, Flags}',
%% or adding a compiler attribute `-ct_expand_trace(Flags)', will enable a form of call trace.
%%
%% `Flags' can be `[]' (no trace) or `[F]', where `F' is `c' (call trace),
%% `r' (return trace), or `x' (exception trace).
%%
%% @end
-module(ct_expand).
-export([parse_transform/2]).

-export([extract_fun/3,
         lfun_rewrite/2]).

-type form()    :: any().
-type forms()   :: [form()].
-type options() :: [{atom(), any()}].


%% Parse transform entry point.
%%
%% Resolves the trace flags with ct_trace_opt/2, then walks Forms using
%% parse_trans:depth_first/4, handing every visited node to xform_fun/6
%% together with the original forms and the trace flags.
%%
%% When the traversal yields `{error, Es}', the error entries are returned
%% in front of the untouched input forms; otherwise the transformed forms
%% are passed through parse_trans:revert/1 and returned.
-spec parse_transform(forms(), options()) ->
    forms().
parse_transform(Forms, Options) ->
    Trace = ct_trace_opt(Options, Forms),
    Visitor = fun(Type, Form, Ctxt, Acc) ->
                      xform_fun(Type, Form, Ctxt, Acc, Forms, Trace)
              end,
    Result = parse_trans:depth_first(Visitor, [], Forms, Options),
    %% The error clause must be tried first: `{error, _}' would also match
    %% the generic two-tuple pattern below.
    case Result of
        {error, Errors} ->
            Errors ++ Forms;
        {Transformed, _FinalAcc} ->
            parse_trans:revert(Transformed)
    end.

%% Determine the active trace flags.
%%
%% A `ct_expand_trace' entry in Options takes precedence; its value must
%% be a list. Without such an entry, Forms is searched for trace
%% attributes. Two spellings are accepted:
%%
%%   -ct_expand_trace(Flags).             (as described in the module doc)
%%   -compile({ct_expand_trace, Flags}).
%%
%% If several trace attributes are present, the last one in Forms wins.
%% Returns `[]' (no tracing) when nothing is configured.
ct_trace_opt(Options, Forms) ->
    case proplists:get_value(ct_expand_trace, Options) of
        undefined ->
            case lists:filtermap(fun trace_attribute/1, Forms) of
                [] ->
                    [];
                [_|_] = Found ->
                    lists:last(Found)
            end;
        Flags when is_list(Flags) ->
            Flags
    end.

%% Return `{true, Flags}' if the form is one of the recognised trace
%% attributes, `false' for every other form (lists:filtermap/2 protocol).
trace_attribute({attribute,_,ct_expand_trace,Flags}) ->
    {true, Flags};
trace_attribute({attribute,_,compile,{ct_expand_trace,Flags}}) ->
    {true, Flags};
trace_attribute(_) ->
    false.

%% Node visitor used by parse_transform/2.
%%
%% For `application' nodes that analyse to a call of ?MODULE:term/1, the
%% call arguments are reverted to plain abstract forms and evaluated with
%% erl_eval:exprs/3, starting from empty bindings. Calls to local functions
%% made during evaluation are served by looking the function up in Forms
%% (extract_fun/3) and evaluating it with eval_lfun/5.
%%
%% If evaluation yields a value and leaves no bindings behind, the node is
%% replaced by the abstract form of that value. Any other evaluation result
%% is reported through parse_trans:error/3 as `cannot_evaluate'.
%%
%% Every other node is returned unchanged. The accumulator is never
%% modified.
xform_fun(application, Form, _Ctxt, Acc, Forms, Trace) ->
    case erl_syntax_lib:analyze_application(Form) of
        {?MODULE, {term, 1}} ->
            LocalHandler =
                fun(FName, FArgs, Bindings) ->
                        FunForm = extract_fun(FName, length(FArgs), Forms),
                        eval_lfun(FunForm, FArgs, Bindings, Forms, Trace)
                end,
            Exprs = parse_trans:revert(
                      erl_syntax:application_arguments(Form)),
            case erl_eval:exprs(Exprs, [], {eval, LocalHandler}) of
                {value, Result, []} ->
                    {abstract(Result), Acc};
                Unexpected ->
                    parse_trans:error(cannot_evaluate, ?LINE,
                                      [{expr, Exprs},
                                       {error, Unexpected}])
            end;
        _NotATermCall ->
            {Form, Acc}
    end;
xform_fun(_Type, Form, _Ctxt, Acc, _Forms, _Trace) ->
    {Form, Acc}.

%% Look up the `function' form for Name/Arity in Forms.
%%
%% Returns the whole `{function, Anno, Name, Arity, Clauses}' form when
%% exactly one form matches. Raises `error:{undef, [{Name, Arity}]}' when
%% no form matches.
extract_fun(Name, Arity, Forms) ->
    IsWanted = fun({function,_,N,A,_Clauses}) ->
                       N == Name andalso A == Arity;
                  (_Other) ->
                       false
               end,
    case lists:filter(IsWanted, Forms) of
        [Found] ->
            Found;
        [] ->
            erlang:error({undef, [{Name, Arity}]})
    end.

%% Evaluate a call to a local function at compile time.
%%
%% The first argument is the `function' form of the callee, Args are the
%% abstract forms of the call arguments and Bs the caller's bindings.
%%
%% Steps, in order:
%%   1. each argument is evaluated left to right with erl_eval:expr/3,
%%      threading the bindings through, and turned back into an abstract
%%      form with abstract/1;
%%   2. the callee's clauses (after lfun_rewrite/2) are wrapped in a fun
%%      expression that is applied to the evaluated arguments;
%%   3. the call is evaluated in fresh bindings, so nothing bound inside
%%      the callee reaches the caller.
%%
%% Returns `{value, Result, Bindings}' where Bindings are the caller's
%% bindings as they stand after argument evaluation.
%%
%% Tracing: a call trace is printed whenever Trace is non-empty, a return
%% trace when it contains `r' or `x'. Exceptions of class `error' are
%% traced when Trace contains `x' and are then re-raised with error/1.
eval_lfun({function,Anno,Name,_Arity,Clauses}, Args, Bs, Forms, Trace) ->
    LineNo = erl_anno:line(Anno),
    try
        Handler = lfh(Forms, Trace),
        EvalArg = fun(ArgExpr, BsIn) ->
                          {value, ArgVal, BsOut} =
                              erl_eval:expr(ArgExpr, BsIn, Handler),
                          {abstract(ArgVal), BsOut}
                  end,
        {AbsArgs, BsAfterArgs} = lists:mapfoldl(EvalArg, Bs, Args),
        FunExpr = {'fun', Anno, {clauses, lfun_rewrite(Clauses, Forms)}},
        CallExpr = {call, Anno, FunExpr, AbsArgs},
        call_trace(Trace =/= [], LineNo, Name, AbsArgs),
        {value, Result, _CalleeBs} =
            erl_eval:expr(CallExpr, erl_eval:new_bindings(), Handler),
        TraceReturn = lists:member(r, Trace) orelse lists:member(x, Trace),
        ret_trace(TraceReturn, LineNo, Name, Args, Result),
        %% Hand back the caller's bindings, not the callee's.
        {value, Result, BsAfterArgs}
    catch
        error:Reason ->
            exception_trace(lists:member(x, Trace), LineNo, Name, Args, Reason),
            error(Reason)
    end.

%% Build the local function handler passed to erl_eval.
%%
%% The result is `{eval, Fun}', where Fun takes the called function's name,
%% the abstract argument forms and the current bindings. It looks the
%% function up in Forms by name and argument count, then evaluates it
%% with eval_lfun/5.
lfh(Forms, Trace) ->
    Handler = fun(FunName, ArgForms, Bindings) ->
                      FunForm = extract_fun(FunName, length(ArgForms), Forms),
                      eval_lfun(FunForm, ArgForms, Bindings, Forms, Trace)
              end,
    {eval, Handler}.

%% Print a call trace line when the first argument is `true'; do nothing
%% when it is `false'. L is the line to report, F the function name and
%% As the abstract forms of the arguments.
call_trace(true, L, F, As) ->
    Rendered = pp_function(F, As),
    io:format("ct_expand (~w): call ~s~n", [L, Rendered]);
call_trace(false, _L, _F, _As) ->
    ok.

%% Render a call as a flat string of the form "name(arg1,arg2,...)".
%% F is the function name (an atom); the second argument is the list of
%% abstract argument forms, each rendered with pp_term/1.
pp_function(F, []) ->
    atom_to_list(F) ++ "()";
pp_function(F, [First|Rest]) ->
    NameStr = atom_to_list(F),
    FirstStr = pp_term(First),
    %% Every remaining argument is preceded by a comma.
    RestStrs = [[",", pp_term(Arg)] || Arg <- Rest],
    lists:flatten([NameStr, "(", FirstStr, RestStrs, ")"]).

%% Render one abstract term as printable characters.
%%
%% Fun expressions given as clauses are pretty-printed as source code with
%% erl_prettypr:format/1, because erl_parse:normalise/1 doesn't handle
%% them. Everything else is converted to a value with erl_parse:normalise/1
%% and printed with ~p.
pp_term({'fun',_Anno,{clauses,_Clauses}} = FunExpr) ->
    Text = erl_prettypr:format(FunExpr),
    io_lib:format("~s", [Text]);
pp_term(AbsTerm) ->
    Value = erl_parse:normalise(AbsTerm),
    io_lib:format("~p", [Value]).

%% Print a return trace line when the first argument is `true'; do nothing
%% when it is `false'. Only the number of Args is reported (as the arity),
%% followed by the returned value Res.
ret_trace(true, L, F, Args, Res) ->
    Arity = length(Args),
    io:format("ct_expand (~w): returned from ~w/~w: ~w~n",
              [L, F, Arity, Res]);
ret_trace(false, _L, _F, _Args, _Res) ->
    ok.

%% Print an exception trace line when the first argument is `true'; do
%% nothing when it is `false'. Only the number of Args is reported (as the
%% arity), followed by the error term Err.
exception_trace(true, L, F, Args, Err) ->
    Arity = length(Args),
    io:format("ct_expand (~w): exception from ~w/~w: ~p~n",
              [L, F, Arity, Err]);
exception_trace(false, _L, _F, _Args, _Err) ->
    ok.


%% Replace references to local functions (`fun Name/Arity') in Exprs by
%% inline fun expressions.
%%
%% Each `{'fun', Anno, {function, Name, Arity}}' node is replaced by a fun
%% expression carrying the clauses of the matching function found in Forms
%% via extract_fun/3. For every other node the callback answers `continue'
%% to parse_trans:plain_transform/2.
lfun_rewrite(Exprs, Forms) ->
    Inline = fun({'fun',Anno,{function,Name,Arity}}) ->
                     {function,_,_,_,Clauses} = extract_fun(Name, Arity, Forms),
                     {'fun',Anno,{clauses,Clauses}};
                (_Other) ->
                     continue
             end,
    parse_trans:plain_transform(Inline, Exprs).


%% abstract/1 - modified from erl_parse:abstract/1.
%%
%% Converts an Erlang term into its abstract form, annotated with
%% erl_anno:new(0). Unlike erl_parse:abstract/1 it also accepts funs that
%% were created by erl_eval:
%%
%%   * an anonymous fun becomes `{'fun', Anno, {clauses, Clauses}}';
%%   * a named fun becomes `{named_fun, Anno, Name, Clauses}'.
%%
%% The bindings captured by such a fun are not carried over; only its
%% clauses are.
%%
%% Any other fun raises `error:function_clause', mimicking
%% erl_parse:abstract/1.
%%
%% Lists that start with integers in the range 0..255 are emitted as
%% `string' nodes as far as the run of such integers reaches; see
%% abstract_string/3.
-type abstract_expr() :: term().
-spec abstract(Data) -> AbsTerm when
      Data :: term(),
      AbsTerm :: abstract_expr().
abstract(T) ->
    abstract(T, erl_anno:new(0)).

abstract(T, A) when is_function(T) ->
    case erlang:fun_info(T, module) of
        {module, erl_eval} ->
            case erl_eval:fun_data(T) of
                {fun_data, _Imports, Clauses} ->
                    {'fun', A, {clauses, Clauses}};
                {named_fun_data, _Imports, Name, Clauses} ->
                    %% Named funs (`fun F(...) -> ... end') are reported
                    %% with their own tag by erl_eval:fun_data/1.
                    {named_fun, A, Name, Clauses};
                false ->
                    erlang:error(function_clause)  % mimicking erl_parse:abstract(T)
            end;
        _ ->
            erlang:error(function_clause)
    end;
abstract(T, A) when is_integer(T) -> {integer,A,T};
abstract(T, A) when is_float(T) -> {float,A,T};
abstract(T, A) when is_atom(T) -> {atom,A,T};
abstract([], A) -> {nil,A};
abstract(B, A) when is_bitstring(B) ->
    %% bitstring_to_list/1 yields whole bytes as integers, plus a trailing
    %% bitstring when the size is not a multiple of 8.
    {bin, A, [abstract_byte(Byte, A) || Byte <- bitstring_to_list(B)]};
abstract([C|T], A) when is_integer(C), 0 =< C, C < 256 ->
    abstract_string(T, [C], A);
abstract([H|T], A) ->
    {cons,A,abstract(H, A),abstract(T, A)};
abstract(Map, A) when is_map(Map) ->
    {map,A,abstract_map(Map, A)};
abstract(Tuple, A) when is_tuple(Tuple) ->
    {tuple,A,abstract_list(tuple_to_list(Tuple), A)}.

%% Continue abstracting a list whose elements so far were all integers in
%% the range 0..255. Seen holds those elements in reverse order.
%%
%%   * When the list ends here, the whole list is emitted as a `string'
%%     node.
%%   * When the next element is another such integer, it is collected.
%%   * Otherwise the remainder (any term, including an improper tail) is
%%     abstracted on its own, and the collected integers are put in front
%%     of it as individual `cons' cells by not_string/3.
abstract_string([], Seen, A) ->
    {string, A, lists:reverse(Seen)};
abstract_string([Ch|Rest], Seen, A) when is_integer(Ch), Ch >= 0, Ch =< 255 ->
    abstract_string(Rest, [Ch|Seen], A);
abstract_string(Rest, Seen, A) ->
    Tail = abstract(Rest, A),
    not_string(Seen, Tail, A).

%% Prepend integers, given in reverse order, to an already abstracted
%% tail. Each integer becomes the head of a new `cons' cell around the
%% result built so far, so the last list element ends up outermost.
not_string(RevChars, Result, A) ->
    lists:foldl(fun(Ch, Tail) ->
                        {cons, A, {integer, A, Ch}, Tail}
                end, Result, RevChars).

%% Abstract every element of a list, in order, with the annotation A.
%% The result is a plain Erlang list of abstract forms.
abstract_list(Elements, A) ->
    [abstract(Element, A) || Element <- Elements].

%% Build the `map_field_assoc' entries for a map. The key/value pairs are
%% sorted first, so the entries come out in term order.
abstract_map(Map, A) ->
    SortedPairs = lists:sort(maps:to_list(Map)),
    lists:map(fun({Key, Value}) ->
                      AbsKey = abstract(Key, A),
                      AbsValue = abstract(Value, A),
                      {map_field_assoc, A, AbsKey, AbsValue}
              end, SortedPairs).

%% Build one `bin_element' of an abstract binary.
%%
%% An integer element is emitted with default size and type. Anything else
%% is treated as a bitstring: its bits are read as one integer and emitted
%% with an explicit size equal to its bit size.
abstract_byte(Elem, Line) ->
    case is_integer(Elem) of
        true ->
            {bin_element, Line, {integer, Line, Elem}, default, default};
        false ->
            NumBits = bit_size(Elem),
            <<Value:NumBits>> = Elem,
            {bin_element, Line, {integer, Line, Value},
             {integer, Line, NumBits}, default}
    end.