From 1c8a7a87a8ae2bacb2e03b754718f76a79a2b918 Mon Sep 17 00:00:00 2001 From: Stephane Epardaud Date: Fri, 24 Nov 2017 14:21:55 +0100 Subject: [PATCH 1/2] Support upgrading foreach loops wrt. subtyping to susport suspendable iterators #285 --- .../CheckInstrumentationVisitor.java | 7 +- .../fibers/instrument/InstrumentClass.java | 5 +- .../fibers/instrument/InstrumentMethod.java | 162 ++++++++++++++++-- .../fibers/instrument/MethodDatabase.java | 83 +++++++-- 4 files changed, 225 insertions(+), 32 deletions(-) diff --git a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/CheckInstrumentationVisitor.java b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/CheckInstrumentationVisitor.java index d2bfe27420..d2689d5969 100644 --- a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/CheckInstrumentationVisitor.java +++ b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/CheckInstrumentationVisitor.java @@ -96,7 +96,7 @@ public boolean isAlreadyInstrumented() { public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { this.className = name; this.isInterface = (access & Opcodes.ACC_INTERFACE) != 0; - this.classEntry = new ClassEntry(superName); + this.classEntry = new ClassEntry(name, superName); classEntry.setInterfaces(interfaces); classEntry.setIsInterface(isInterface); } @@ -137,7 +137,7 @@ public MethodVisitor visitMethod(final int access, final String name, final Stri } } suspendable = InstrumentClass.suspendableToSuperIfAbstract(access, suspendable); - classEntry.set(name, desc, suspendable); + classEntry.set(name, desc, suspendable, (access & Opcodes.ACC_BRIDGE) != 0); if (suspendable == null) // look for @Suspendable annotation return new MethodVisitor(ASMAPI) { @@ -153,7 +153,8 @@ public AnnotationVisitor visitAnnotation(String adesc, boolean visible) { @Override public void visitEnd() { super.visitEnd(); - classEntry.set(name, desc, InstrumentClass.suspendableToSuperIfAbstract(access, susp ? SuspendableType.SUSPENDABLE : SuspendableType.NON_SUSPENDABLE)); + classEntry.set(name, desc, InstrumentClass.suspendableToSuperIfAbstract(access, susp ? SuspendableType.SUSPENDABLE : SuspendableType.NON_SUSPENDABLE), + (access & Opcodes.ACC_BRIDGE) != 0); hasSuspendable = hasSuspendable | susp; } }; diff --git a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentClass.java b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentClass.java index 7187bd7d79..7505bb8e0c 100644 --- a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentClass.java +++ b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentClass.java @@ -148,7 +148,8 @@ public MethodVisitor visitMethod(final int access, final String name, final Stri final SuspendableType setSuspendable = classEntry.check(name, desc); if (setSuspendable == null) - classEntry.set(name, desc, markedSuspendable != null ? markedSuspendable : SuspendableType.NON_SUSPENDABLE); + classEntry.set(name, desc, markedSuspendable != null ? markedSuspendable : SuspendableType.NON_SUSPENDABLE, + (access & Opcodes.ACC_BRIDGE) != 0); final SuspendableType suspendable = max(markedSuspendable, setSuspendable, SuspendableType.NON_SUSPENDABLE); @@ -200,7 +201,7 @@ private void commit() { if (db.isDebug()) db.log(LogLevel.INFO, "Method %s#%s%s suspendable: %s (markedSuspendable: %s setSuspendable: %s)", className, name, desc, susp, susp, setSuspendable); - classEntry.set(name, desc, susp); + classEntry.set(name, desc, susp, (access & Opcodes.ACC_BRIDGE) != 0); if (susp == SuspendableType.SUSPENDABLE && checkAccessForMethodInstrumentation(access)) { if (isSynchronized(access)) { diff --git a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentMethod.java b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentMethod.java index 26a97bec8e..0f51ace6ad 100644 --- a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentMethod.java +++ b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/InstrumentMethod.java @@ -60,27 +60,15 @@ import static co.paralleluniverse.fibers.instrument.MethodDatabase.isMethodHandleInvocation; import static co.paralleluniverse.fibers.instrument.MethodDatabase.isReflectInvocation; import static co.paralleluniverse.fibers.instrument.MethodDatabase.isSyntheticAccess; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; + +import java.util.*; + import org.objectweb.asm.AnnotationVisitor; import org.objectweb.asm.Label; import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; -import org.objectweb.asm.tree.AbstractInsnNode; -import org.objectweb.asm.tree.AnnotationNode; -import org.objectweb.asm.tree.InsnList; -import org.objectweb.asm.tree.InvokeDynamicInsnNode; -import org.objectweb.asm.tree.JumpInsnNode; -import org.objectweb.asm.tree.LabelNode; -import org.objectweb.asm.tree.LineNumberNode; -import org.objectweb.asm.tree.LocalVariableNode; -import org.objectweb.asm.tree.MethodInsnNode; -import org.objectweb.asm.tree.MethodNode; -import org.objectweb.asm.tree.TryCatchBlockNode; +import org.objectweb.asm.tree.*; import org.objectweb.asm.tree.analysis.Analyzer; import org.objectweb.asm.tree.analysis.AnalyzerException; import org.objectweb.asm.tree.analysis.BasicValue; @@ -144,6 +132,7 @@ class InstrumentMethod { this.mn = mn; try { + upgradeForeach(mn); Analyzer a = new TypeAnalyzer(db); this.frames = a.analyze(className, mn); this.lvarStack = mn.maxLocals; @@ -157,6 +146,147 @@ class InstrumentMethod { } } + public void upgradeForeach(MethodNode mn) { + ListIterator it = mn.instructions.iterator(); + int i = 0; + while(it.hasNext()) { + AbstractInsnNode instr = it.next(); + if(instr.getType() == AbstractInsnNode.METHOD_INSN + && (instr.getOpcode() == Opcodes.INVOKEVIRTUAL + || instr.getOpcode() == Opcodes.INVOKEINTERFACE)) { + MethodInsnNode mCall = (MethodInsnNode) instr; + if(mCall.name.equals("iterator") + // we can't check the return type here because Eclipse makes it Iterator + // but javac respects subtypes (but only here) +// && mCall.desc.equals("()Ljava/util/Iterator;") + ) { + checkForeach(mCall); + } + } + } + } + + private void checkForeach(MethodInsnNode iteratorCall) { + // iterable.iterator(): invoke iterator(), store, jump to test [only Eclipse] + AbstractInsnNode mCallPlus1 = iteratorCall.getNext(); + if(mCallPlus1 == null + || mCallPlus1.getType() != AbstractInsnNode.VAR_INSN + || mCallPlus1.getOpcode() != Opcodes.ASTORE) + return; + VarInsnNode mCallStore = (VarInsnNode) mCallPlus1; + int iteratorVarIndex = mCallStore.var; + AbstractInsnNode mCallPlus2 = mCallPlus1.getNext(); + if(mCallPlus2 == null) + return; + + boolean testBeforeNext; + AbstractInsnNode testInstr; + if(mCallPlus2.getType() == AbstractInsnNode.JUMP_INSN + && mCallPlus2.getOpcode() == Opcodes.GOTO){ + testBeforeNext = true; + // jump to the hasNext() test: load, invoke hasNext(), ifne to body + JumpInsnNode jumpToTest = (JumpInsnNode) mCallPlus2; + testInstr = getJumpTarget(jumpToTest.label); + }else{ + // continue hasNext() test: label, load, invoke hasNext(), ifeq to end + testBeforeNext = false; + if(mCallPlus2.getType() != AbstractInsnNode.LABEL) + return; + testInstr = mCallPlus2.getNext(); + } + + if(testInstr == null + || testInstr.getType() != AbstractInsnNode.VAR_INSN + || testInstr.getOpcode() != Opcodes.ALOAD) + return; + VarInsnNode testLoad = (VarInsnNode) testInstr; + if(testLoad.var != iteratorVarIndex) + return; + AbstractInsnNode testLoadPlus1 = testLoad.getNext(); + if(testLoadPlus1 == null + || testLoadPlus1.getType() != AbstractInsnNode.METHOD_INSN + || testLoadPlus1.getOpcode() != Opcodes.INVOKEINTERFACE) + return; + MethodInsnNode hasNextCall = (MethodInsnNode) testLoadPlus1; + if(!hasNextCall.name.equals("hasNext") + || !hasNextCall.owner.equals("java/util/Iterator") + || !hasNextCall.desc.equals("()Z")) + return; + AbstractInsnNode testLoadPlus2 = hasNextCall.getNext(); + if(testLoadPlus2 == null + || testLoadPlus2.getType() != AbstractInsnNode.JUMP_INSN) + return; + if(testBeforeNext && testLoadPlus2.getOpcode() != Opcodes.IFNE) + return; + if(!testBeforeNext && testLoadPlus2.getOpcode() != Opcodes.IFEQ) + return; + + // Now check body: load, invoke next() + JumpInsnNode jumpToBody = (JumpInsnNode) testLoadPlus2; + AbstractInsnNode bodyInstr = testBeforeNext ? getJumpTarget(jumpToBody.label) : jumpToBody.getNext(); + if(bodyInstr == null + || bodyInstr.getType() != AbstractInsnNode.VAR_INSN + || bodyInstr.getOpcode() != Opcodes.ALOAD) + return; + VarInsnNode bodyLoad = (VarInsnNode) bodyInstr; + if(bodyLoad.var != iteratorVarIndex) + return; + AbstractInsnNode bodyLoadPlus1 = bodyLoad.getNext(); + if(bodyLoadPlus1 == null + || bodyLoadPlus1.getType() != AbstractInsnNode.METHOD_INSN + || bodyLoadPlus1.getOpcode() != Opcodes.INVOKEINTERFACE) + return; + MethodInsnNode nextCall = (MethodInsnNode) bodyLoadPlus1; + if(!nextCall.name.equals("next") + || !nextCall.owner.equals("java/util/Iterator") + || !nextCall.desc.equals("()Ljava/lang/Object;")) + return; + + MethodDatabase.ClassEntry iterableClassEntry = db.getOrLoadClassEntry(iteratorCall.owner); + if(iterableClassEntry == null) + return; + if(!iterableClassEntry.implementsInterface("java/lang/Iterable", db)) + return; + MethodDatabase.ClassEntry methodOwnerClass = iterableClassEntry.getClassImplementingMethod("iterator()", db); + // iteratorType contains the "L...;" parts + String iteratorType = methodOwnerClass.getReturnType("iterator()"); + if(iteratorType == null || iteratorType.equals("Ljava/util/Iterator;")) + return; + + MethodDatabase.ClassEntry iteratorClass = + db.getOrLoadClassEntry(iteratorType.substring(1, iteratorType.length()-1)); + if(iteratorClass == null) + return; + MethodDatabase.ClassEntry nextOwnerClass = iteratorClass.getClassImplementingMethod("next()", db); + if(nextOwnerClass == null) + return; + String nextMethodOwner = nextOwnerClass.getName(); + boolean nextMethodInterface = nextOwnerClass.isInterface(); + + MethodDatabase.ClassEntry hasNextOwnerClass = iteratorClass.getClassImplementingMethod("hasNext()", db); + if(hasNextOwnerClass == null) + return; + String hasNextMethodOwner = hasNextOwnerClass.getName(); + boolean hasNextMethodInterface = hasNextOwnerClass.isInterface(); + + iteratorCall.desc = "()"+iteratorType; + hasNextCall.owner = hasNextMethodOwner; + hasNextCall.setOpcode(hasNextMethodInterface ? Opcodes.INVOKEINTERFACE : Opcodes.INVOKEVIRTUAL); + hasNextCall.itf = hasNextMethodInterface; + nextCall.owner = nextMethodOwner; + nextCall.setOpcode(nextMethodInterface ? Opcodes.INVOKEINTERFACE : Opcodes.INVOKEVIRTUAL); + nextCall.itf = nextMethodInterface; + } + + private AbstractInsnNode getJumpTarget(LabelNode label) { + AbstractInsnNode next = label.getNext(); + while(next.getType() == AbstractInsnNode.FRAME + || next.getType() == AbstractInsnNode.LINE) { + next = next.getNext(); + } + return next; + } + private void collectCallsites() { if (suspCallsBcis == null) { suspCallsBcis = new int[8]; diff --git a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/MethodDatabase.java b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/MethodDatabase.java index 9c3c65c7a5..c14b2e9688 100644 --- a/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/MethodDatabase.java +++ b/quasar-core/src/main/java/co/paralleluniverse/fibers/instrument/MethodDatabase.java @@ -47,12 +47,8 @@ import java.io.IOException; import java.io.InputStream; import java.lang.ref.WeakReference; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.NavigableMap; -import java.util.TreeMap; +import java.util.*; + import org.objectweb.asm.ClassReader; import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.Opcodes; @@ -239,7 +235,7 @@ public synchronized ClassEntry getClassEntry(String className) { public synchronized ClassEntry getOrCreateClassEntry(String className, String superType) { ClassEntry ce = classes.get(className); if (ce == null) { - ce = new ClassEntry(superType); + ce = new ClassEntry(className, superType); classes.put(className, ce); } return ce; @@ -454,30 +450,40 @@ public static boolean isProblematicClass(String className) { || className.startsWith("org/apache/log4j/"); } - private static final ClassEntry CLASS_NOT_FOUND = new ClassEntry(""); + private static final ClassEntry CLASS_NOT_FOUND = new ClassEntry("", ""); public enum SuspendableType { NON_SUSPENDABLE, SUSPENDABLE_SUPER, SUSPENDABLE }; public static final class ClassEntry { - private final HashMap methods; + public final HashMap methods; + public final HashSet bridges; private String sourceName; private String sourceDebugInfo; private boolean isInterface; private String[] interfaces; private final String superName; + private final String name; private boolean instrumented; private volatile boolean requiresInstrumentation; - public ClassEntry(String superName) { + public ClassEntry(String name, String superName) { + this.name = name; this.superName = superName; this.methods = new HashMap<>(); + this.bridges = new HashSet<>(); + } + + public String getName() { + return name; } - public void set(String name, String desc, SuspendableType suspendable) { + public void set(String name, String desc, SuspendableType suspendable, boolean bridge) { String nameAndDesc = key(name, desc); methods.put(nameAndDesc, suspendable); + if(bridge) + bridges.add(nameAndDesc); } public String getSourceName() { @@ -568,6 +574,61 @@ public boolean isInstrumented() { public void setInstrumented(boolean instrumented) { this.instrumented = instrumented; } + + public boolean implementsInterface(String name, MethodDatabase db) { + for(String interf : interfaces){ + if(interf.equals(name)) + return true; + } + if(superName != null){ + ClassEntry superClass = db.getOrLoadClassEntry(superName); + if(superClass != null && superClass.implementsInterface(name, db)) + return true; + } + for(String interf : interfaces){ + ClassEntry superClass = db.getOrLoadClassEntry(interf); + if(superClass != null && superClass.implementsInterface(name, db)) + return true; + } + return false; + } + + public ClassEntry getClassImplementingMethod(String methodNameAndParams, MethodDatabase db) { + for (Map.Entry entry : methods.entrySet()) { + String key = entry.getKey(); + if (key.substring(0, key.indexOf(')')+1).equals(methodNameAndParams) + && !bridges.contains(key)) + return this; + } + if(superName != null){ + ClassEntry superClass = db.getOrLoadClassEntry(superName); + if(superClass != null) { + ClassEntry ret = superClass.getClassImplementingMethod(methodNameAndParams, db); + if (ret != null) + return ret; + } + } + for(String interf : interfaces){ + ClassEntry superClass = db.getOrLoadClassEntry(interf); + if(superClass != null) { + ClassEntry ret = superClass.getClassImplementingMethod(methodNameAndParams, db); + if (ret != null) + return ret; + } + } + return null; + } + + public String getReturnType(String methodNameAndParams) { + for (Map.Entry entry : methods.entrySet()) { + String key = entry.getKey(); + int retIndex = key.indexOf(")")+1; + if (key.substring(0, retIndex).equals(methodNameAndParams) + && !bridges.contains(key)) + return key.substring(retIndex); + } + return null; + } } public static class ExtractSuperClass extends ClassVisitor { From 6579e6b8310c9fd2589756900649357d3b942f6c Mon Sep 17 00:00:00 2001 From: Stephane Epardaud Date: Fri, 24 Nov 2017 14:22:26 +0100 Subject: [PATCH 2/2] Test for #285: suspendable iterators --- .../instrument/SuspendableIteratorTest.java | 203 ++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 quasar-core/src/test/java/co/paralleluniverse/fibers/instrument/SuspendableIteratorTest.java diff --git a/quasar-core/src/test/java/co/paralleluniverse/fibers/instrument/SuspendableIteratorTest.java b/quasar-core/src/test/java/co/paralleluniverse/fibers/instrument/SuspendableIteratorTest.java new file mode 100644 index 0000000000..a72f9155d9 --- /dev/null +++ b/quasar-core/src/test/java/co/paralleluniverse/fibers/instrument/SuspendableIteratorTest.java @@ -0,0 +1,203 @@ +package co.paralleluniverse.fibers.instrument; + +import co.paralleluniverse.fibers.Fiber; +import co.paralleluniverse.fibers.SuspendExecution; +import co.paralleluniverse.fibers.Suspendable; +import co.paralleluniverse.strands.SuspendableCallable; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import static co.paralleluniverse.fibers.TestsHelper.exec; +import static org.junit.Assert.assertEquals; + +public class SuspendableIteratorTest { + + private final List results = new ArrayList<>(); + + static class SuspendableIteratorImpl { + List elems = new ArrayList<>(); + { + elems.add("A"); + elems.add("B"); + elems.add("C"); + } + + @Suspendable + public String next() { + try { + Fiber.park(); + } catch (SuspendExecution e) { + throw new AssertionError(e); + } + return elems.remove(0); + } + + @Suspendable + public boolean hasNext() { + try { + Fiber.park(); + } catch (SuspendExecution e) { + throw new AssertionError(e); + } + return !elems.isEmpty(); + } + }; + + interface SuspendableIteratorInterface extends Iterator { + @Override + @Suspendable + T next(); + + @Override + @Suspendable + boolean hasNext(); + } + + static abstract class SuspendableIteratorClass implements Iterator { + @Override + @Suspendable + public abstract T next(); + + @Override + @Suspendable + public abstract boolean hasNext(); + } + + static class SuspendableListWithIteratorInterface implements Iterable { + @Override + public SuspendableIteratorInterface iterator() { + return new SuspendableIteratorInterface(){ + SuspendableIteratorImpl impl = new SuspendableIteratorImpl(); + + @Override + public void remove() { + } + + @Suspendable + @Override + public String next() { + return impl.next(); + } + + @Suspendable + @Override + public boolean hasNext() { + return impl.hasNext(); + } + }; + } + } + + static class SuspendableListWithIteratorClass implements Iterable { + @Override + public SuspendableIteratorClass iterator() { + return new SuspendableIteratorClass(){ + SuspendableIteratorImpl impl = new SuspendableIteratorImpl(); + + @Override + public void remove() { + } + + @Suspendable + @Override + public String next() { + return impl.next(); + } + + @Suspendable + @Override + public boolean hasNext() { + return impl.hasNext(); + } + }; + } + } + + @Suspendable + private void suspendableListWithIteratorInterface(){ + SuspendableListWithIteratorInterface l = new SuspendableListWithIteratorInterface(); + for(String elem : l){ + results.add(elem); + } + } + + @Test + public void testSuspendableListWithIteratorInterface(){ + Fiber co = new Fiber((String) null, null, (SuspendableCallable) null) { + @Override + protected Object run() throws SuspendExecution, InterruptedException { + suspendableListWithIteratorInterface(); + return null; + } + }; + runTest(co); + } + + @Suspendable + private void suspendableListWithIteratorClass(){ + SuspendableListWithIteratorClass l = new SuspendableListWithIteratorClass(); + for(String elem : l){ + results.add(elem); + } + } + + @Test + public void testSuspendableListWithIteratorClass(){ + Fiber co = new Fiber((String) null, null, (SuspendableCallable) null) { + @Override + protected Object run() throws SuspendExecution, InterruptedException { + suspendableListWithIteratorClass(); + return null; + } + }; + runTest(co); + } + + @Suspendable + private void suspendableListWithIteratorClassMultiple(){ + SuspendableListWithIteratorClass l = new SuspendableListWithIteratorClass(); + for(String elem : l){ + results.add(elem); + for(String elem2 : l){ + results.add(elem2); + } + } + for(String elem : l){ + results.add(elem); + } + } + + @Test + public void testSuspendableListWithIteratorClassMultiple(){ + Fiber co = new Fiber((String) null, null, (SuspendableCallable) null) { + @Override + protected Object run() throws SuspendExecution, InterruptedException { + suspendableListWithIteratorClassMultiple(); + return null; + } + }; + runTest(co, 35, Arrays.asList("A", "A", "B", "C", + "B", "A", "B", "C", + "C", "A", "B", "C", + "A", "B", "C")); + } + + private void runTest(Fiber co) { + runTest(co, 7, Arrays.asList("A", "B", "C")); + } + + private void runTest(Fiber co, int iters, List expected){ + try{ + for(int i=0;i