Skip to main content
aboutsummaryrefslogtreecommitdiffstats
blob: bb720f81ee28410fd6b5b2af2629f7a6f1c53ec0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
/*****************************************************************************
 * Copyright (c) 2015 Christian W. Damus and others.
 * 
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *   Christian W. Damus - Initial API and implementation
 *   
 *****************************************************************************/

package org.eclipse.papyrus.tests.framework.m2t.xtend.templates

import java.util.Map
import com.google.common.collect.Maps
import com.google.common.base.Joiner
import com.google.common.collect.Iterables
import com.google.common.collect.Ordering
import javax.inject.Singleton
import java.io.File

/**
 * Extensions for managing imports in generated Java files.
 * <p>
 * A template is expanded inside {@link #managingImports(File, Functions.Function0) managingImports};
 * while it runs, it calls {@link #imported(String)} to shorten qualified class names and
 * emits {@link #markImports()} where the import section belongs. Once the template has
 * produced its text, the first marker is replaced by the sorted {@code import} statements
 * collected for that file.
 * <p>
 * The current file is tracked per thread, but the import tables live in a plain hash map.
 * NOTE(review): that map is not synchronized - confirm that generation is single-threaded
 * before sharing this singleton across threads.
 */
@Singleton
class Importator {
    /** Placeholder emitted by templates at the position of the import section. */
    static val IMPORTS_MARKER = "$$$imports$$$"

    /** Import tables (simple name -> qualified name), keyed by the file being generated. */
    Map<File, Map<String, String>> importsByFile = Maps.newHashMap
    /** The file whose template is currently being expanded on the calling thread. */
    ThreadLocal<File> tlFile = new ThreadLocal;
    
    /**
     * Forgets every import recorded so far for the calling thread's current file.
     */
    def reset() {
        imports.clear
    }
    
    /**
     * Expands a template with import management in effect for the given file.
     * 
     * @param file the file being generated; it keys the import table used by the template
     * @param template the template to expand
     * @return the template's text, with the first import marker replaced by the import statements
     */
    def CharSequence managingImports(File file, () => CharSequence template) {
        // Remember the enclosing file, if any, so that a nested call does not
        // leave the outer template without its current file
        val enclosingFile = tlFile.get
        
        tlFile.set(file)
        try {
            importify(template.apply)
        } finally {
            if (enclosingFile == null) {
                tlFile.remove
            } else {
                tlFile.set(enclosingFile)
            }
            importsByFile.remove(file)
        }
    }
    
    /** The file currently being generated on the calling thread, or null if none. */
    private def file() {
        tlFile.get
    }
    
    /**
     * Obtains the import table of the calling thread's current file, creating and
     * registering an empty one on first access.
     * 
     * @return the mutable mapping of simple class name to imported qualified class name
     */
    def Map<String, String> imports() {
        if (importsByFile.containsKey(file))
            importsByFile.get(file)
        else
            Maps.<String, String> newHashMap => [
                importsByFile.put(file, it)
            ]
    }
    
    /**
     * Records an import of a class, if possible, and answers the name by which the
     * generated code should refer to it.
     * 
     * @param qualifiedClassName the qualified name of the class to reference
     * @return the simple name if the class is (now) imported; otherwise the name as given,
     *     which happens when it has no package qualifier or when its simple name is already
     *     taken by a different imported class
     */
    def String imported(String qualifiedClassName) {
        val simpleName = qualifiedClassName.substring(qualifiedClassName.lastIndexOf('.') + 1)
        val existing = imports.get(simpleName)
        
        if ((simpleName == qualifiedClassName) || ((existing != null) && (existing != qualifiedClassName))) {
            // Cannot import the same name again
            qualifiedClassName
        } else {
            imports.put(simpleName, qualifiedClassName)
            simpleName
        }
    }
    
    /**
     * Answers the placeholder that a template emits where the import statements belong.
     */
    def String markImports() {
        IMPORTS_MARKER
    }
    
    /**
     * Replaces the first import marker in the text with the current file's import
     * statements, sorted by qualified name and separated by the platform line separator.
     * Text without a marker is returned unchanged.
     */
    private def CharSequence importify(CharSequence text) {
        val importsText = Joiner.on(System.getProperty("line.separator")).join(
            Iterables.transform(Ordering.natural.sortedCopy(imports.values), [f|'import ' + f + ';'])
        )
        
        // Splice literally rather than via String.replaceFirst: a regex replacement string
        // gives '$' and '\' special meaning, which breaks on class names such as Outer$Inner
        val source = text.toString
        val markerStart = source.indexOf(IMPORTS_MARKER)
        if (markerStart < 0) {
            source
        } else {
            source.substring(0, markerStart) + importsText
                + source.substring(markerStart + IMPORTS_MARKER.length)
        }
    }
}

Back to the top