Skip to content

Commit

Permalink
[arith] Allow scalar access of BASH_LINENO, etc. in arith.
Browse files Browse the repository at this point in the history
  • Loading branch information
akinomyoga committed Apr 23, 2020
1 parent 9fc20e7 commit c2365f5
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 21 deletions.
14 changes: 14 additions & 0 deletions osh/sh_expr_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from mycpp.mylib import tagswitch, switch
from osh import bool_stat
from osh import word_
from osh import word_eval

import libc # for fnmatch

Expand Down Expand Up @@ -327,6 +328,12 @@ def _EvalLhsAndLookupArith(self, node):
lval = self.EvalArithLhs(node, runtime.NO_SPID)
val = OldValue(lval, self.mem, self.exec_opts)

# BASH_LINENO, etc.
if val.tag_() in (value_e.MaybeStrArray, value_e.AssocArray) and lval.tag_() == lvalue_e.Named:
named_lval = cast(lvalue__Named, lval)
if word_eval.CheckCompatArray(named_lval.name):
val = word_eval.ResolveCompatArray(val)

# This error message could be better, but we already have one
#if val.tag_() == value_e.MaybeStrArray:
# e_die("Can't use assignment like ++ or += on arrays")
Expand All @@ -347,6 +354,13 @@ def EvalToInt(self, node):
Also used internally.
"""
val = self.Eval(node)

# BASH_LINENO, etc.
if val.tag_() in (value_e.MaybeStrArray, value_e.AssocArray) and node.tag_() == arith_expr_e.VarRef:
tok = cast(Token, node)
if word_eval.CheckCompatArray(tok.val):
val = word_eval.ResolveCompatArray(val)

# TODO: Can we avoid the runtime cost of adding location info?
span_id = location.SpanForArithExpr(node)
i = self._ValToIntOrError(val, span_id=span_id)
Expand Down
45 changes: 24 additions & 21 deletions osh/word_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,26 @@
# For compatibility, ${BASH_SOURCE} and ${BASH_SOURCE[@]} are both valid.
# ${FUNCNAME} and ${BASH_LINENO} are also the same type of of special variables.
_STRING_AND_ARRAY = ['BASH_SOURCE', 'FUNCNAME', 'BASH_LINENO']
def CheckCompatArray(var_name):
# type: (str) -> bool
return var_name in _STRING_AND_ARRAY

def ResolveCompatArray(val):
# type: (value_t) -> value_t
"""Decay ${array} to ${array[0]}."""
if val.tag_() == value_e.MaybeStrArray:
array_val = cast(value__MaybeStrArray, val)
s = array_val.strs[0] if array_val.strs else None
elif val.tag_() == value_e.AssocArray:
assoc_val = cast(value__AssocArray, val)
s = assoc_val.d['0'] if '0' in assoc_val.d else None
else:
raise AssertionError(val.tag_())

if s is None:
return value.Undef()
else:
return value.Str(s)


def EvalSingleQuoted(part):
Expand Down Expand Up @@ -794,23 +814,6 @@ def _DecayArray(self, val):
tmp = [s for s in val.strs if s is not None]
return value.Str(sep.join(tmp))

def _BashArrayCompat(self, val):
# type: (value_t) -> value_t
"""Decay ${array} to ${array[0]}."""
if val.tag_() == value_e.MaybeStrArray:
array_val = cast(value__MaybeStrArray, val)
s = array_val.strs[0] if array_val.strs else None
elif val.tag_() == value_e.AssocArray:
assoc_val = cast(value__AssocArray, val)
s = assoc_val.d['0'] if '0' in assoc_val.d else None
else:
raise AssertionError(val.tag_())

if s is None:
return value.Undef()
else:
return value.Str(s)

def _EmptyStrOrError(self, val, token=None):
# type: (value_t, Optional[Token]) -> value_t
if val.tag_() == value_e.Undef:
Expand Down Expand Up @@ -987,9 +990,9 @@ def _EvalBracedVarSub(self, part, part_vals, quoted):
# ${array@a} is a string
# TODO: An IR for ${} might simplify these lengthy conditions
pass
elif var_name in _STRING_AND_ARRAY:
elif CheckCompatArray(var_name):
# for ${BASH_SOURCE}, etc.
val = self._BashArrayCompat(val)
val = ResolveCompatArray(val)
else:
e_die("Array %r can't be referred to as a scalar (without @ or *)",
var_name, part=part)
Expand Down Expand Up @@ -1216,9 +1219,9 @@ def _EvalSimpleVarSub(self, token, part_vals, quoted):
# TODO: Special case for LINENO
val = self.mem.GetVar(var_name)
if val.tag_() in (value_e.MaybeStrArray, value_e.AssocArray):
if var_name in _STRING_AND_ARRAY:
if CheckCompatArray(var_name):
# for $BASH_SOURCE, etc.
val = self._BashArrayCompat(val)
val = ResolveCompatArray(val)
else:
e_die("Array %r can't be referred to as a scalar (without @ or *)",
var_name, token=token)
Expand Down

0 comments on commit c2365f5

Please sign in to comment.